From 18f48dfc2cd42c77a4d5ac0c8e5c68d2a3d35a6d Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 6 Jan 2020 16:26:51 -0800 Subject: [PATCH 01/11] [REFACTOR][IR] Variable -> VarNode --- include/tvm/arithmetic.h | 12 ++++----- include/tvm/expr.h | 12 ++++----- include/tvm/ir.h | 2 +- include/tvm/operation.h | 14 +++++----- src/arithmetic/bound_deducer.cc | 16 ++++++------ src/arithmetic/const_int_bound.cc | 2 +- src/arithmetic/detect_linear_equation.cc | 10 +++---- src/arithmetic/domain_touched.cc | 6 ++--- src/arithmetic/int_set.cc | 12 ++++----- src/arithmetic/modular_set.cc | 2 +- src/arithmetic/rewrite_simplify.cc | 2 +- src/arithmetic/rewrite_simplify.h | 2 +- src/autotvm/touch_extractor.cc | 6 ++--- src/codegen/build_module.cc | 2 +- src/codegen/codegen_c.cc | 18 ++++++------- src/codegen/codegen_c.h | 18 ++++++------- src/codegen/codegen_cuda.cc | 10 +++---- src/codegen/codegen_cuda.h | 8 +++--- src/codegen/codegen_opencl.cc | 6 ++--- src/codegen/codegen_opencl.h | 10 +++---- src/codegen/codegen_opengl.cc | 6 ++--- src/codegen/codegen_opengl.h | 14 +++++----- src/codegen/codegen_source_base.cc | 8 +++--- src/codegen/codegen_source_base.h | 10 +++---- src/codegen/spirv/codegen_spirv.cc | 6 ++--- src/codegen/spirv/codegen_spirv.h | 6 ++--- src/codegen/stackvm/codegen_stackvm.cc | 6 ++--- src/codegen/stackvm/codegen_stackvm.h | 8 +++--- src/contrib/hybrid/codegen_hybrid.cc | 6 ++--- src/contrib/hybrid/codegen_hybrid.h | 6 ++--- src/lang/attr_functor.h | 4 +-- src/lang/data_layout.cc | 4 +-- src/lang/expr.cc | 6 ++--- src/op/compute_op.cc | 4 +-- src/op/extern_op.cc | 2 +- src/op/hybrid_op.cc | 32 +++++++++++------------ src/op/op_util.cc | 2 +- src/op/placeholder_op.cc | 2 +- src/op/scan_op.cc | 2 +- src/op/tensor_compute_op.cc | 4 +-- src/op/tensorize.cc | 10 +++---- src/pass/verify_gpu_code.cc | 8 +++--- src/relay/ir/alpha_equal.cc | 2 +- src/relay/ir/hash.cc | 4 +-- src/relay/pass/type_solver.cc | 4 +-- src/schedule/bound.cc | 4 +-- src/schedule/graph.cc | 2 +- src/schedule/message_passing.cc | 2 +- src/schedule/schedule_dataflow_rewrite.cc | 20 +++++++------- src/schedule/schedule_ops.cc | 4 +-- tests/cpp/ir_functor_test.cc | 6 ++--- topi/include/topi/detail/broadcast.h | 4 +-- 52 files changed, 189 insertions(+), 189 deletions(-) diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index e5f75673a9cb..d135d30a8fbb 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -564,7 +564,7 @@ IntSet EvalSet(Expr e, * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(Expr e, - const std::unordered_map& dom_map); + const std::unordered_map& dom_map); /*! * \brief Find an symbolic integer set that contains is union over @@ -586,7 +586,7 @@ IntSet EvalSet(Range r, * \return An integer set that can cover all the possible values. */ IntSet EvalSet(IntSet s, - const std::unordered_map& dom_map); + const std::unordered_map& dom_map); /*! * \brief Same as EvalSet, but takes unordered_map * @@ -595,7 +595,7 @@ IntSet EvalSet(IntSet s, * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(Range r, - const std::unordered_map& dom_map); + const std::unordered_map& dom_map); /*! \brief Map from Expr to IntSet */ using ExprIntSetMap = std::unordered_map; @@ -609,7 +609,7 @@ using ExprIntSetMap = std::unordered_map; */ ExprIntSetMap EvalSetForEachSubExpr( Expr e, - const std::unordered_map& dom_map); + const std::unordered_map& dom_map); /*! * \brief Create an union set of all sets @@ -654,8 +654,8 @@ IntSet DeduceBound(Expr v, Expr cond, * \return An integer set that always satisfies the condition. */ IntSet DeduceBound(Expr v, Expr cond, - const std::unordered_map& hint_map, - const std::unordered_map& relax_map); + const std::unordered_map& hint_map, + const std::unordered_map& relax_map); /*! * \brief Infer a regular domain that covers all the calls or provides within the given statement. diff --git a/include/tvm/expr.h b/include/tvm/expr.h index aee565dcbc9c..f78c3bc8ab77 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -102,7 +102,7 @@ class Var; * - Let * - LetStmt */ -class Variable : public ExprNode { +class VarNode : public ExprNode { public: /*! * \brief The hint to the variable name. @@ -118,7 +118,7 @@ class Variable : public ExprNode { } static constexpr const char* _type_key = "Variable"; - TVM_DECLARE_FINAL_OBJECT_INFO(Variable, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode); }; /*! \brief a named variable in TVM */ @@ -139,18 +139,18 @@ class Var : public Expr { * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const Variable* operator->() const { + const VarNode* operator->() const { return get(); } /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const Variable* get() const { - return static_cast(data_.get()); + const VarNode* get() const { + return static_cast(data_.get()); } /*! \brief type indicate the container type */ - using ContainerType = Variable; + using ContainerType = VarNode; }; // Backward compatibility, will be removed later. diff --git a/include/tvm/ir.h b/include/tvm/ir.h index b1cefff1e90e..5dd05beb7302 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -37,7 +37,7 @@ namespace tvm { namespace ir { using IntImm = tvm::IntImm; -using Variable = tvm::Variable; +using Variable = tvm::VarNode; /*! \brief constant unsigned integer. */ class UIntImm : public ExprNode { diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 681d06897355..ad8f8259c016 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -109,7 +109,7 @@ class OperationNode : public ir::FunctionBaseNode { virtual void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const = 0; /*! * \brief Gather the bound from output tensor. @@ -173,7 +173,7 @@ class PlaceholderOpNode : public OperationNode { void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, @@ -251,7 +251,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; Stmt BuildProvide( const Stage& stage, @@ -304,7 +304,7 @@ class TensorComputeOpNode : public BaseComputeOpNode { void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; Stmt BuildProvide( const Stage& stage, @@ -379,7 +379,7 @@ class ScanOpNode : public OperationNode { void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, @@ -446,7 +446,7 @@ class ExternOpNode : public OperationNode { void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, @@ -514,7 +514,7 @@ class HybridOpNode : public OperationNode { void PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const final; void GatherBound( const Operation& self, diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index bb2e3400cc5b..b9c423c0ad47 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -78,8 +78,8 @@ class BoundDeducer: public ExprVisitor { friend class BoundDeduceInputChecker; friend class Converter; BoundDeducer(Expr target, Expr expr, - const std::unordered_map& hint_map, - const std::unordered_map& relax_map) + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} void Deduce(); @@ -187,8 +187,8 @@ class BoundDeducer: public ExprVisitor { CompareOp ReverseOp(CompareOp comp_op); Expr target_; Expr expr_; - const std::unordered_map& hint_map_; - const std::unordered_map& relax_map_; + const std::unordered_map& hint_map_; + const std::unordered_map& relax_map_; ExprIntSetMap expr_map_; std::vector path_; size_t iter_{0}; @@ -330,8 +330,8 @@ void BoundDeducer::Relax() { } IntSet DeduceBound(Expr v, Expr e, - const std::unordered_map& hint_map, - const std::unordered_map& relax_map) { + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) { BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); if (!d.success_) return IntSet::nothing(); @@ -352,11 +352,11 @@ IntSet DeduceBound(Expr v, Expr e, IntSet DeduceBound(Expr v, Expr e, const Map& hint_map, const Map& relax_map) { - std::unordered_map hmap; + std::unordered_map hmap; for (auto kv : hint_map) { hmap[kv.first.get()] = kv.second; } - std::unordered_map rmap; + std::unordered_map rmap; for (auto kv : relax_map) { rmap[kv.first.get()] = kv.second; } diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index ef405d8026c9..9a25c47f37a0 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -282,7 +282,7 @@ class ConstIntBoundAnalyzer::Impl : } } - Entry VisitExpr_(const Variable* op) final { + Entry VisitExpr_(const VarNode* op) final { Var v = GetRef(op); auto it = var_map_.find(v); if (it != var_map_.end()) { diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index b8ec974b436c..5c964d55bc9c 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -96,7 +96,7 @@ class LinearEqDetector ret.coeff = MulCombine(a.base, b.coeff); return ret; } - LinearEqEntry VisitExpr_(const Variable* op, const Expr& e) final { + LinearEqEntry VisitExpr_(const VarNode* op, const Expr& e) final { LinearEqEntry ret; if (op == var_.get()) { ret.coeff = make_const(op->dtype, 1); @@ -152,7 +152,7 @@ Array DetectLinearEquation(const Expr& e, const Array& vars) { base = std::move(ret.base); } - std::unordered_set vset; + std::unordered_set vset; for (size_t i = vars.size(); i > 1; --i) { vset.insert(vars[i - 1].get()); // The previous coeff contains the variable @@ -167,11 +167,11 @@ Array DetectLinearEquation(const Expr& e, const Array& vars) { // Detect clip condition as min max value bool DetectClipBound( const Expr& cond, - std::unordered_map* bmap) { + std::unordered_map* bmap) { int flag = 0; Var var; auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) { - if (const Variable* v = n.as()) { + if (const VarNode* v = n.as()) { if (bmap->count(v)) { if (flag == 0) { var = Downcast(n); @@ -244,7 +244,7 @@ void SplitCommExpr(const Expr& e, std::vector* ret) { Array DetectClipBound(const Expr& e, const Array& vars) { std::vector splits; SplitCommExpr(e, &splits); - std::unordered_map rmap; + std::unordered_map rmap; for (Var v : vars) { rmap[v.get()] = IntervalEntry(); } diff --git a/src/arithmetic/domain_touched.cc b/src/arithmetic/domain_touched.cc index 02f357837362..3fc82313e5b7 100644 --- a/src/arithmetic/domain_touched.cc +++ b/src/arithmetic/domain_touched.cc @@ -54,7 +54,7 @@ class FuncTouchedDomain final : public StmtExprVisitor { } void VisitStmt_(const For *op) final { - const Variable* var = op->loop_var.get(); + const VarNode* var = op->loop_var.get(); dom_map_[var] = IntSet::range( Range::make_by_min_extent(op->min, op->extent)); StmtExprVisitor::VisitStmt_(op); @@ -73,7 +73,7 @@ class FuncTouchedDomain final : public StmtExprVisitor { if (op->attr_key == attr::thread_extent) { const IterVarNode* thread_axis = op->node.as(); CHECK(thread_axis); - const Variable* var = thread_axis->var.get(); + const VarNode* var = thread_axis->var.get(); dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value)); StmtExprVisitor::VisitStmt_(op); dom_map_.erase(var); @@ -111,7 +111,7 @@ class FuncTouchedDomain final : public StmtExprVisitor { const Tensor &tensor_; bool consider_calls_, consider_provides_; std::vector > bounds_; - std::unordered_map dom_map_; + std::unordered_map dom_map_; }; Domain DomainTouched(Stmt stmt, const Tensor &tensor, bool consider_calls, bool consider_provides) { diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index bf1cdf0466b7..042d85e28751 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -388,7 +388,7 @@ class IntervalSetEvaluator : return IntervalSet::SinglePoint(GetRef(op)); } - IntervalSet VisitExpr_(const Variable* op) final { + IntervalSet VisitExpr_(const VarNode* op) final { Var var = GetRef(op); auto it = dom_map_.find(var); if (it != dom_map_.end()) { @@ -720,7 +720,7 @@ Map ConvertDomMap(const Map& dom_map) { } Map ConvertDomMap( - const std::unordered_map& dom_map) { + const std::unordered_map& dom_map) { Map dmap; for (auto kv : dom_map) { dmap.Set(GetRef(kv.first), kv.second); @@ -746,7 +746,7 @@ IntSet EvalSet(Expr e, } IntSet EvalSet(Expr e, - const std::unordered_map& dom_map) { + const std::unordered_map& dom_map) { return EvalSet(e, ConvertDomMap(dom_map)); } @@ -761,12 +761,12 @@ IntSet EvalSet(Range r, } IntSet EvalSet(Range r, - const std::unordered_map& dom_map) { + const std::unordered_map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } IntSet EvalSet(IntSet s, - const std::unordered_map& dom_map) { + const std::unordered_map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); IntervalSetEvaluator m(&ana, dmap); @@ -796,7 +796,7 @@ class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { ExprIntSetMap EvalSetForEachSubExpr( Expr e, - const std::unordered_map& dom_map) { + const std::unordered_map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); SubExprIntervalSetEvaluator m(&ana, dmap); diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index a83e98760baa..f40492325ca4 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -232,7 +232,7 @@ class ModularSetAnalyzer::Impl : } } - Entry VisitExpr_(const Variable* op) final { + Entry VisitExpr_(const VarNode* op) final { Var v = GetRef(op); auto it = var_map_.find(v); if (it != var_map_.end()) { diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index f883bf145f59..2c1fa5dc69f9 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -1728,7 +1728,7 @@ VisitExpr_(const Call* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Variable* op) { +VisitExpr_(const VarNode* op) { Var var = GetRef(op); auto it = var_map_.find(var); if (it != var_map_.end()) { diff --git a/src/arithmetic/rewrite_simplify.h b/src/arithmetic/rewrite_simplify.h index cf9dd6edbefa..4984bc524924 100644 --- a/src/arithmetic/rewrite_simplify.h +++ b/src/arithmetic/rewrite_simplify.h @@ -70,7 +70,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { Expr VisitExpr_(const Not* op) override; Expr VisitExpr_(const Select* op) override; Expr VisitExpr_(const Call* op) override; - Expr VisitExpr_(const Variable* op) override; + Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const Cast* op) override; Expr VisitExpr_(const Let* op) override; diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index 51b13549f296..c986ef7cbc2d 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -51,7 +51,7 @@ class IndexParser: public ExprVisitor { this->VisitExpr(expr); } - void VisitExpr_(const Variable* op) final { + void VisitExpr_(const VarNode* op) final { // TODO(lmzheng): handle more index types (multiple occurrence) if (pattern_map.count(op) == 0) { pattern_map[op] = TouchPattern(); @@ -61,7 +61,7 @@ class IndexParser: public ExprVisitor { } void VisitExpr_(const Mul* op) final { - if (op->a.as()) { + if (op->a.as()) { if (const auto stride = op->b.as()) { next_stride_ = stride->value; } @@ -69,7 +69,7 @@ class IndexParser: public ExprVisitor { ExprVisitor::VisitExpr_(op); } - std::unordered_map pattern_map; + std::unordered_map pattern_map; private: int64_t next_stride_ = 1; diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 38f0b9532133..86f47e75e08d 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -348,7 +348,7 @@ Buffer BufferWithOffsetAlignment(Array shape, bool has_any = false; if (!compact) { for (const auto& it : shape) { - if (it.as()) { + if (it.as()) { has_any = true; break; } diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index a3f145994f2c..bacfed02efaa 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -146,7 +146,7 @@ void CodeGenC::PrintSSAAssign( // Print a reference expression to a buffer. std::string CodeGenC::GetBufferRef( - DataType t, const Variable* buffer, Expr index) { + DataType t, const VarNode* buffer, Expr index) { std::ostringstream os; std::string vid = GetVarID(buffer); std::string scope; @@ -265,13 +265,13 @@ std::string CodeGenC::GetStructRef( } -bool CodeGenC::HandleTypeMatch(const Variable* buf_var, DataType t) const { +bool CodeGenC::HandleTypeMatch(const VarNode* buf_var, DataType t) const { auto it = handle_data_type_.find(buf_var); if (it == handle_data_type_.end()) return false; return it->second == t; } -void CodeGenC::RegisterHandleType(const Variable* buf_var, DataType t) { +void CodeGenC::RegisterHandleType(const VarNode* buf_var, DataType t) { auto it = handle_data_type_.find(buf_var); if (it == handle_data_type_.end()) { handle_data_type_[buf_var] = t; @@ -296,11 +296,11 @@ void CodeGenC::PrintVecElemStore(const std::string& vec, } std::string CodeGenC::GetVecLoad( - DataType t, const Variable* buffer, Expr base) { + DataType t, const VarNode* buffer, Expr base) { return GetBufferRef(t, buffer, base); } -void CodeGenC::PrintVecStore(const Variable* buffer, +void CodeGenC::PrintVecStore(const VarNode* buffer, DataType t, Expr base, const std::string& value) { std::string ref = GetBufferRef(t, buffer, base); @@ -462,7 +462,7 @@ void CodeGenC::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*) this->PrintExpr(op->value, value); os << CastFromTo(value.str(), op->value.dtype(), op->dtype); } -void CodeGenC::VisitExpr_(const Variable* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const VarNode* op, std::ostream& os) { // NOLINT(*) os << GetVarID(op); } void CodeGenC::VisitExpr_(const Add* op, std::ostream& os) { // NOLINT(*) @@ -791,7 +791,7 @@ void CodeGenC::VisitStmt_(const Allocate* op) { int32_t constant_size = op->constant_allocation_size(); CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; - const Variable* buffer = op->buffer_var.as(); + const VarNode* buffer = op->buffer_var.as(); std::string scope = alloc_storage_scope_.at(buffer); PrintStorageScope(scope, stream); stream << ' '; @@ -812,11 +812,11 @@ void CodeGenC::VisitStmt_(const AttrStmt* op) { } } } else if (op->attr_key == ir::attr::storage_scope) { - const Variable* v = op->node.as(); + const VarNode* v = op->node.as(); CHECK(v); alloc_storage_scope_[v] = op->value.as()->value; } else if (op->attr_key == ir::attr::volatile_scope) { - const Variable* v = op->node.as(); + const VarNode* v = op->node.as(); CHECK(v); volatile_buf_.insert(v); } diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index eae1e4961b77..b5555f25ace9 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -102,7 +102,7 @@ class CodeGenC : */ virtual void InitFuncState(LoweredFunc f); // expression - void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*) @@ -160,9 +160,9 @@ class CodeGenC : const std::string&op, DataType op_type, Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*) // print vector load - virtual std::string GetVecLoad(DataType t, const Variable* buffer, Expr base); + virtual std::string GetVecLoad(DataType t, const VarNode* buffer, Expr base); // print vector store - virtual void PrintVecStore(const Variable* buffer, + virtual void PrintVecStore(const VarNode* buffer, DataType t, Expr base, const std::string& value); // NOLINT(*) // print load of single element @@ -180,28 +180,28 @@ class CodeGenC : DataType t, const Expr& buffer, const Expr& index, int kind); // print reference to a buffer as type t in index. virtual std::string GetBufferRef( - DataType t, const Variable* buffer, Expr index); + DataType t, const VarNode* buffer, Expr index); /*! * \brief If buffer is allocated as type t. * \param buf_var The buffer variable. * \param t The type to be checked. */ - bool HandleTypeMatch(const Variable* buf_var, DataType t) const; + bool HandleTypeMatch(const VarNode* buf_var, DataType t) const; /*! * \brief Register the data type of buf_var * \param buf_var The buffer variable. * \param t The type to be checked. */ - void RegisterHandleType(const Variable* buf_var, DataType t); + void RegisterHandleType(const VarNode* buf_var, DataType t); // override void PrintSSAAssign( const std::string& target, const std::string& src, DataType t) final; /*! \brief restrict keyword */ std::string restrict_keyword_{""}; /*! \brief the storage scope of allocation */ - std::unordered_map alloc_storage_scope_; + std::unordered_map alloc_storage_scope_; /*! \brief the data type of allocated buffers */ - std::unordered_map handle_data_type_; + std::unordered_map handle_data_type_; /*! \brief reserves common C keywords */ void ReserveKeywordsAsUnique(); @@ -209,7 +209,7 @@ class CodeGenC : /*! \brief whether to print in SSA form */ bool print_ssa_form_{false}; /*! \brief set of volatile buf access */ - std::unordered_set volatile_buf_; + std::unordered_set volatile_buf_; }; } // namespace codegen diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 06b542a66323..5a4255b6d79d 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -371,11 +371,11 @@ void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) { void CodeGenCUDA::VisitStmt_(const AttrStmt* op) { if (op->attr_key == attr::fragment_shape) { - const Variable* buffer = op->node.as(); + const VarNode* buffer = op->node.as(); const StringImm* shape_str = op->value.as(); fragment_shapes[buffer] = shape_str->value; } else if (op->attr_key == attr::fragment_layout) { - const Variable* buffer = op->node.as(); + const VarNode* buffer = op->node.as(); const StringImm* layout_str = op->value.as(); fragment_layouts[buffer] = layout_str->value; } @@ -397,7 +397,7 @@ void CodeGenCUDA::VisitStmt_(const Allocate* op) { int32_t constant_size = op->constant_allocation_size(); CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; - const Variable* buffer = op->buffer_var.as(); + const VarNode* buffer = op->buffer_var.as(); std::string scope = alloc_storage_scope_.at(buffer); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { @@ -528,7 +528,7 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(* } void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t, - const Variable* variable, std::ostream &os) { + const VarNode* variable, std::ostream &os) { std::stringstream type; PrintType(t, type); std::string shape_str = fragment_shapes[variable]; @@ -550,7 +550,7 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t, } int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope, - const Variable* variable, int32_t size) { + const VarNode* variable, int32_t size) { std::string shape_str = fragment_shapes[variable]; size_t m, n, k; size_t last_pos = 0, pos = 0; diff --git a/src/codegen/codegen_cuda.h b/src/codegen/codegen_cuda.h index 74d6fba35fc7..51d6e0c6b454 100644 --- a/src/codegen/codegen_cuda.h +++ b/src/codegen/codegen_cuda.h @@ -81,13 +81,13 @@ class CodeGenCUDA final : public CodeGenC { // whether need mma.h bool need_mma_h_{false}; - std::unordered_map fragment_shapes; - std::unordered_map fragment_layouts; + std::unordered_map fragment_shapes; + std::unordered_map fragment_layouts; friend void PrintConst(const FloatImm* op, std::ostream& os, CodeGenCUDA* p); void PrintWmmaScope( - const std::string& scope, DataType t, const Variable* variable, std::ostream& os); + const std::string& scope, DataType t, const VarNode* variable, std::ostream& os); int32_t GetWmmaFragmentSize( - const std::string &scope, const Variable* variable, int32_t size); + const std::string &scope, const VarNode* variable, int32_t size); }; } // namespace codegen diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index e466e28749ef..b273b5f75320 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -144,7 +144,7 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type"; } -void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, DataType t, +void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, Expr base, std::ostream& os) { // NOLINT(*) if (!HandleTypeMatch(buffer, t.element_of())) { os << '('; @@ -160,7 +160,7 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, DataType t, PrintExpr(base, os); } std::string CodeGenOpenCL::GetVecLoad( - DataType t, const Variable* buffer, Expr base) { + DataType t, const VarNode* buffer, Expr base) { std::ostringstream os; os << "vload" << t.lanes() << "(0, "; PrintVecAddr(buffer, t, base, os); @@ -168,7 +168,7 @@ std::string CodeGenOpenCL::GetVecLoad( return os.str(); } -void CodeGenOpenCL::PrintVecStore(const Variable* buffer, +void CodeGenOpenCL::PrintVecStore(const VarNode* buffer, DataType t, Expr base, const std::string& value) { this->PrintIndent(); diff --git a/src/codegen/codegen_opencl.h b/src/codegen/codegen_opencl.h index 36324eb431ae..b1e979db492d 100644 --- a/src/codegen/codegen_opencl.h +++ b/src/codegen/codegen_opencl.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -44,13 +44,13 @@ class CodeGenOpenCL final : public CodeGenC { void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const Call* op) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - std::string GetVecLoad(DataType t, const Variable* buffer, + std::string GetVecLoad(DataType t, const VarNode* buffer, Expr base) final; - void PrintVecStore(const Variable* buffer, + void PrintVecStore(const VarNode* buffer, DataType t, Expr base, const std::string& value) final; // NOLINT(*) // the address of load/store - void PrintVecAddr(const Variable* buffer, DataType t, + void PrintVecAddr(const VarNode* buffer, DataType t, Expr base, std::ostream& os); // NOLINT(*) std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc index 29fcf85557b9..905b99a04e56 100644 --- a/src/codegen/codegen_opengl.cc +++ b/src/codegen/codegen_opengl.cc @@ -194,7 +194,7 @@ void CodeGenOpenGL::VisitStmt_(const Store* op) { } // texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r -std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) { +std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, Expr index) { std::ostringstream os; os << "texelFetch(" << GetVarID(buffer) << ", ivec2(int("; PrintExpr(index, os); @@ -207,7 +207,7 @@ std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) { // Print a reference expression to a buffer. // Format: texelFetch(buffer, index, 0).r std::string CodeGenOpenGL::GetBufferRef( - DataType t, const Variable* buffer, Expr index) { + DataType t, const VarNode* buffer, Expr index) { CHECK_EQ(t.lanes(), 1) << "Vector type not supported."; CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported."; @@ -269,7 +269,7 @@ void CodeGenOpenGL::VisitStmt_(const Evaluate* op) { } CHECK_EQ(call->args.size(), 2); - auto buffer = call->args[0].as(); + auto buffer = call->args[0].as(); auto value = call->args[1]; // Doesn't support store to vector. diff --git a/src/codegen/codegen_opengl.h b/src/codegen/codegen_opengl.h index 46e87a8165c1..b3536edf4865 100644 --- a/src/codegen/codegen_opengl.h +++ b/src/codegen/codegen_opengl.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -44,8 +44,8 @@ class CodeGenOpenGL final : public CodeGenC { void InitFuncState(LoweredFunc f) final; void BindThreadIndex(const IterVar& iv) final; void VisitStmt_(const Store* op) final; - std::string TexelFetch(const Variable* buffer, Expr index); - std::string GetBufferRef(DataType t, const Variable* buffer, Expr index) final; + std::string TexelFetch(const VarNode* buffer, Expr index); + std::string GetBufferRef(DataType t, const VarNode* buffer, Expr index) final; void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) // Codegen for immediate values @@ -58,9 +58,9 @@ class CodeGenOpenGL final : public CodeGenC { void VisitStmt_(const Evaluate* op) final; // NOLINT(*) private: - const Variable* output_{nullptr}; - std::unordered_set inputs_; - const Variable* output_iter_var_{nullptr}; + const VarNode* output_{nullptr}; + std::unordered_set inputs_; + const VarNode* output_iter_var_{nullptr}; std::unordered_map shaders_; std::string thread_extent_var_; }; diff --git a/src/codegen/codegen_source_base.cc b/src/codegen/codegen_source_base.cc index 7c4ed5b91c8b..aa3b6ef68fd5 100644 --- a/src/codegen/codegen_source_base.cc +++ b/src/codegen/codegen_source_base.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -69,7 +69,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { return e.vid; } -std::string CodeGenSourceBase::AllocVarID(const Variable* v) { +std::string CodeGenSourceBase::AllocVarID(const VarNode* v) { CHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint; std::string key = v->name_hint; @@ -78,7 +78,7 @@ std::string CodeGenSourceBase::AllocVarID(const Variable* v) { return vid; } -std::string CodeGenSourceBase::GetVarID(const Variable* v) const { +std::string CodeGenSourceBase::GetVarID(const VarNode* v) const { auto it = var_idmap_.find(v); CHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint; diff --git a/src/codegen/codegen_source_base.h b/src/codegen/codegen_source_base.h index 7fd0eef98a90..b39ee46b0a17 100644 --- a/src/codegen/codegen_source_base.h +++ b/src/codegen/codegen_source_base.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -66,13 +66,13 @@ class CodeGenSourceBase { * \param v The variable. * \return the variable name. */ - std::string AllocVarID(const Variable* v); + std::string AllocVarID(const VarNode* v); /*! * \brief Get a variable name. * \param v The variable. * \return the variable name. */ - std::string GetVarID(const Variable* v) const; + std::string GetVarID(const VarNode* v) const; /*! * \brief Get the SSA ID corresponds to src * If necessary, generate new assignment @@ -110,7 +110,7 @@ class CodeGenSourceBase { /*! \brief the stream to be printed */ std::ostringstream stream; /*! \brief name of each variable */ - std::unordered_map var_idmap_; + std::unordered_map var_idmap_; private: /*! \brief assignment map of ssa */ diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index 0709965d0e8b..72709ebf9def 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -126,7 +126,7 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const Call* op) { return value; } -spirv::Value CodeGenSPIRV::VisitExpr_(const Variable* op) { +spirv::Value CodeGenSPIRV::VisitExpr_(const VarNode* op) { auto it = var_map_.find(op); CHECK(it != var_map_.end()) << "cannot find variable " << op->name_hint; return it->second; @@ -613,12 +613,12 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) { } } } else if (op->attr_key == ir::attr::storage_scope) { - const Variable* v = op->node.as(); + const VarNode* v = op->node.as(); CHECK(v); storage_info_[v].scope = runtime::StorageScope::make(op->value.as()->value); } else if (op->attr_key == ir::attr::volatile_scope) { - const Variable* v = op->node.as(); + const VarNode* v = op->node.as(); CHECK(v); storage_info_[v].is_volatile = true; } diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index 5cd88c9f267a..c6833057418d 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -62,7 +62,7 @@ class CodeGenSPIRV: return VisitExpr(e); } // override codegen - spirv::Value VisitExpr_(const Variable* op) override; + spirv::Value VisitExpr_(const VarNode* op) override; spirv::Value VisitExpr_(const Cast* op) override; spirv::Value VisitExpr_(const IntImm* op) override; spirv::Value VisitExpr_(const UIntImm* op) override; @@ -139,9 +139,9 @@ class CodeGenSPIRV: // Likely branch uint32_t weight_likely_branch_{128}; // the storage scope of allocation - std::unordered_map storage_info_; + std::unordered_map storage_info_; // The definition of local variable. - std::unordered_map var_map_; + std::unordered_map var_map_; // The analyzer. std::unique_ptr analyzer_; }; diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index 23bb008a0e7e..ce0f45b6ec10 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -81,7 +81,7 @@ int CodeGenStackVM::GetStrID(const std::string& key) { return sid; } -int CodeGenStackVM::AllocVarID(const Variable* v) { +int CodeGenStackVM::AllocVarID(const VarNode* v) { CHECK(!var_idmap_.count(v)); int vid = static_cast(vm_.heap_size); CHECK_EQ(vm_.heap_size, var_idmap_.size()); @@ -91,7 +91,7 @@ int CodeGenStackVM::AllocVarID(const Variable* v) { return vid; } -int CodeGenStackVM::GetVarID(const Variable* v) const { +int CodeGenStackVM::GetVarID(const VarNode* v) const { auto it = var_idmap_.find(v); CHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint; @@ -290,7 +290,7 @@ void CodeGenStackVM::VisitExpr_(const FloatImm* op) { LOG(FATAL) << "Float Imm is not supported"; } -void CodeGenStackVM::VisitExpr_(const Variable* op) { +void CodeGenStackVM::VisitExpr_(const VarNode* op) { int vid = this->GetVarID(op); this->PushOp(StackVM::LOAD_HEAP, vid); } diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index 7a4c0ab797fd..36287f783c58 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -96,13 +96,13 @@ class CodeGenStackVM * \param v The variable. * \return the heap index of the var. */ - int AllocVarID(const Variable* v); + int AllocVarID(const VarNode* v); /*! * \brief Get a variable name. * \param v The variable. * \return the heap index of the var. */ - int GetVarID(const Variable* v) const; + int GetVarID(const VarNode* v) const; // Push binary operator void PushBinary(StackVM::OpCode op_int64, const Expr& a, @@ -111,7 +111,7 @@ class CodeGenStackVM void PushCast(DataType dst, DataType src); // overloadable functions // expression - void VisitExpr_(const Variable* op) final; + void VisitExpr_(const VarNode* op) final; void VisitExpr_(const Load* op) final; void VisitExpr_(const Let* op) final; void VisitExpr_(const Call* op) final; @@ -156,7 +156,7 @@ class CodeGenStackVM /*! \brief The vm to be generated */ StackVM vm_; /*! \brief id of each variable */ - std::unordered_map var_idmap_; + std::unordered_map var_idmap_; /*! \brief id of each string */ std::unordered_map str_idmap_; /*! \brief id of each global function */ diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 00b2c230c5bb..ea2bf4d92112 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -138,7 +138,7 @@ void CodeGenHybrid::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*) } } -void CodeGenHybrid::VisitExpr_(const Variable* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const VarNode* op, std::ostream& os) { // NOLINT(*) os << GetVarID(op); } void CodeGenHybrid::VisitExpr_(const Add* op, std::ostream& os) { // NOLINT(*) @@ -410,7 +410,7 @@ void CodeGenHybrid::PrintIndent() { stream << std::string(indent_, ' '); } -std::string CodeGenHybrid::GetVarID(const Variable *v) { +std::string CodeGenHybrid::GetVarID(const VarNode *v) { if (binds_.count(v)) return binds_[v]; auto key = std::make_pair(static_cast(v), 0); @@ -489,7 +489,7 @@ void CodeGenHybrid::DumpStmt(const Stmt &stmt, if (auto tensor = inputs[i].as()) { stream << GetTensorID(tensor->op, tensor->value_index); } else { - auto var = inputs[i].as(); + auto var = inputs[i].as(); CHECK(var) << "Input should either be a tensor or a variable!"; stream << GetVarID(var); } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 27c97c73e333..a43b98aa6174 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -90,7 +90,7 @@ class CodeGenHybrid : return os.str(); } // expression - void VisitExpr_(const Variable* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*) @@ -154,7 +154,7 @@ class CodeGenHybrid : * Values are the corresponding IDs.*/ std::map, std::string> id_map_; /*! \brief Variables (keys) binded to the threads (values). */ - std::map binds_; + std::map binds_; /*! * \brief Find an unallocated name for the given prefix. * \param prefix The given prefix. @@ -166,7 +166,7 @@ class CodeGenHybrid : * \brief Get or allocate the ID for the given variable. * \param v The given variable. */ - std::string GetVarID(const Variable *v); + std::string GetVarID(const VarNode *v); /*! * \brief Get or allocate the ID for the given tensor. * \param func The tensor to allocate a name. diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 51b355e81df3..49ab2fd2a2ba 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -81,7 +81,7 @@ class AttrFunctor { virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; // deep comparison of symbolic integer expressions. - virtual R VisitAttr_(const Variable* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT; @@ -116,7 +116,7 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(UIntImm); ATTR_FUNCTOR_DISPATCH(FloatImm); ATTR_FUNCTOR_DISPATCH(StringImm); - ATTR_FUNCTOR_DISPATCH(Variable); + ATTR_FUNCTOR_DISPATCH(VarNode); ATTR_FUNCTOR_DISPATCH(Add); ATTR_FUNCTOR_DISPATCH(Sub); ATTR_FUNCTOR_DISPATCH(Mul); diff --git a/src/lang/data_layout.cc b/src/lang/data_layout.cc index c4a6b35c0724..95f59c802465 100644 --- a/src/lang/data_layout.cc +++ b/src/lang/data_layout.cc @@ -251,7 +251,7 @@ inline Array TransformIndex(const Array& src_index, const Array& src_axis, const Array& transform_rule) { Array result; - std::unordered_map bind_map; + std::unordered_map bind_map; for (size_t i = 0; i < src_index.size(); ++i) { bind_map[src_axis[i]->var.get()] = src_index[i]; } @@ -287,7 +287,7 @@ inline Array TransformShape(const Array& src_shape, // for major-axis, bind the corresponding size // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule, // e.g., (C * 16 + c) / 32 - std::unordered_map bind_map; + std::unordered_map bind_map; std::unordered_set symbolic_var_set; for (size_t i = 0; i < src_shape.size(); ++i) { Expr orig_shape = src_shape[i]; diff --git a/src/lang/expr.cc b/src/lang/expr.cc index eed693808708..eca44d80f186 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -39,10 +39,10 @@ Expr::Expr(std::string str) : Expr(ir::StringImm::make(str)) {} Var::Var(std::string name_hint, DataType t) - : Var(Variable::make(t, name_hint)) {} + : Var(VarNode::make(t, name_hint)) {} -Var Variable::make(DataType t, std::string name_hint) { - ObjectPtr node = make_object(); +Var VarNode::make(DataType t, std::string name_hint) { + ObjectPtr node = make_object(); node->dtype = t; node->name_hint = std::move(name_hint); return Var(node); diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 6146284554b4..5cfecb6628d0 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -212,7 +212,7 @@ Operation ComputeOpNode::ReplaceInputs( void ComputeOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) { @@ -607,7 +607,7 @@ Stmt TransformUpdate(const Stage& stage, Stmt body, Stmt update) { Array conds; - std::unordered_set banned; + std::unordered_set banned; for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { IterVar iv = stage->leaf_iter_vars[i]; auto iit = stage->iter_var_attrs.find(iv); diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc index c6102ed556e0..4559e9c8923b 100644 --- a/src/op/extern_op.cc +++ b/src/op/extern_op.cc @@ -113,7 +113,7 @@ Operation ExternOpNode::ReplaceInputs( void ExternOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { for (Tensor t : this->inputs) { auto it = out_dom_map->find(t); diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc index b4f29f5a36c8..0b996f1780ce 100644 --- a/src/op/hybrid_op.cc +++ b/src/op/hybrid_op.cc @@ -128,7 +128,7 @@ Operation HybridOpNode::ReplaceInputs( void HybridOpNode::PropBoundToInputs( const Operation &self, arith::Analyzer* analyzer, - const std::unordered_map &dom_map, + const std::unordered_map &dom_map, std::unordered_map* out_dom_map) const { auto curr_inputs = InputTensors(); for (Tensor t : curr_inputs) { @@ -223,7 +223,7 @@ Stmt ApplyLoopShapes(const Stage &stage, const std::unordered_map &dom_map, Stmt stmt) { class LoopSpliter : public StmtExprMutator { Expr factor; - const Variable *parent; + const VarNode *parent; IterVar inner, outer; public: @@ -249,7 +249,7 @@ Stmt ApplyLoopShapes(const Stage &stage, Stmt VisitStmt_(const For *op) final { if (op->loop_var.get() == parent) { - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = inner + outer * factor; Stmt ret = ir::Substitute(op->body, rmap); Expr cond = likely(outer * factor < (op->extent - inner)); @@ -267,8 +267,8 @@ Stmt ApplyLoopShapes(const Stage &stage, class LoopFuser : public StmtExprMutator { const IterVar &parent; - const Variable *inner; - const Variable *outer; + const VarNode *inner; + const VarNode *outer; bool under_outer; Expr extent; @@ -283,7 +283,7 @@ Stmt ApplyLoopShapes(const Stage &stage, Stmt VisitStmt_(const For* op) final { if (op->loop_var.get() == inner) { CHECK(under_outer); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexmod(parent, op->extent); extent = op->extent; fused = true; @@ -291,7 +291,7 @@ Stmt ApplyLoopShapes(const Stage &stage, } else if (op->loop_var.get() == outer) { under_outer = true; Stmt body = this->VisitStmt(op->body); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexdiv(parent, extent); body = ir::Substitute(body, rmap); under_outer = false; @@ -299,7 +299,7 @@ Stmt ApplyLoopShapes(const Stage &stage, op->for_type, op->device_api, body); } else if (under_outer) { Stmt body = this->VisitStmt(op->body); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent); body = ir::Substitute(body, rmap); extent = extent * op->extent; @@ -327,11 +327,11 @@ Stmt ApplyLoopShapes(const Stage &stage, Stmt ApplyLoopAnnotations(const Stage &stage, const std::unordered_map &rebased, Stmt stmt) { class LoopAnnotator : public StmtMutator { - const Variable *var; + const VarNode *var; const IterVarAttr &attr; public: - LoopAnnotator(const Variable *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {} + LoopAnnotator(const VarNode *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {} Stmt VisitStmt_(const For *op) final { if (op->loop_var.get() == var) { @@ -342,7 +342,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage, CHECK(Equal(iter_var->dom->extent, op->extent)) << "Thread extent and loop extent mismatch!\n"; } - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = iter_var; Stmt body = ir::Substitute(op->body, rmap); return AttrStmt::make(iter_var, "thread_extent", op->extent, body); @@ -360,7 +360,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage, int found = 0; const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var; - const Variable *var = actual->var.get(); + const VarNode *var = actual->var.get(); ForType expected = IterVarTypeToForType(iter_var->iter_type); IterVarAttr attr; if (stage->iter_var_attrs.count(iter_var)) { @@ -389,7 +389,7 @@ Stmt ApplyLoopAnnotations(const Stage &stage, Stmt ApplyLoopOrder(const Stage &stage, const std::unordered_map &dom_map, const std::unordered_map &rebased, Stmt stmt) { - std::vector current_order; + std::vector current_order; PostOrderVisit(stmt, [¤t_order](const ObjectRef& node) { if (const For *op = node.as()) current_order.push_back(op->loop_var.get()); @@ -397,7 +397,7 @@ Stmt ApplyLoopOrder(const Stage &stage, std::reverse(current_order.begin(), current_order.end()); auto &required_ord = stage->leaf_iter_vars; CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!"; - std::unordered_map reorder; + std::unordered_map reorder; bool need_reorder = false; for (size_t i = 0; i < current_order.size(); ++i) { auto ¤t = current_order[i]; @@ -413,12 +413,12 @@ Stmt ApplyLoopOrder(const Stage &stage, class LoopReorder : public StmtMutator { const Stage &stage; const std::unordered_map &dom_map; - const std::unordered_map &reorder; + const std::unordered_map &reorder; public: LoopReorder(const Stage &stage, const std::unordered_map &dom_map, - const std::unordered_map &reorder) + const std::unordered_map &reorder) : stage(stage), dom_map(dom_map), reorder(reorder) {} Stmt VisitStmt_(const For* op) final { diff --git a/src/op/op_util.cc b/src/op/op_util.cc index 4a6d0d2f302a..789b8492b974 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -229,7 +229,7 @@ Expr ReplaceTensor(Expr expr, Stmt Substitute(Stmt s, const std::unordered_map& value_map) { - std::unordered_map init; + std::unordered_map init; for (const auto& kv : value_map) { init[kv.first->var.get()] = kv.second; } diff --git a/src/op/placeholder_op.cc b/src/op/placeholder_op.cc index 6414d5c39ac1..2ec10caf07a9 100644 --- a/src/op/placeholder_op.cc +++ b/src/op/placeholder_op.cc @@ -79,7 +79,7 @@ Operation PlaceholderOpNode::ReplaceInputs( void PlaceholderOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { } diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index ef2d1efcc089..7e7e04911627 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -177,7 +177,7 @@ Operation ScanOpNode::ReplaceInputs( void ScanOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) { diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc index a6252df05246..209dfe89accc 100644 --- a/src/op/tensor_compute_op.cc +++ b/src/op/tensor_compute_op.cc @@ -109,7 +109,7 @@ Operation TensorComputeOpNode::ReplaceInputs( void TensorComputeOpNode::PropBoundToInputs( const Operation& self, arith::Analyzer* analyzer, - const std::unordered_map& dom_map, + const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { for (size_t i = 0; i < this->inputs.size(); ++i) { Tensor t = this->inputs[i]; @@ -182,7 +182,7 @@ Stmt TensorComputeOpNode::BuildProvide( } // Check variable remap - std::unordered_map vmap; + std::unordered_map vmap; ir::ArgBinder binder(&vmap); // Map the expressions passed in the call to the TensorIntrin, to the placeholder diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index 0df8e889efeb..70cb689cff91 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -85,7 +85,7 @@ size_t InferTensorizeRegion( schedule::PassUpDomain(stage, dom_map, &up_state); // Get domains if inputs std::unordered_map in_dom; - std::unordered_map temp_dmap; + std::unordered_map temp_dmap; arith::Analyzer analyzer; Array inputs = self->InputTensors(); for (Tensor t : inputs) { @@ -119,7 +119,7 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, const ComputeLoopNest& n, size_t tloc) { // Veirfication step. - std::unordered_set banned; + std::unordered_set banned; CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1); CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 || n.init_nest.size() == 0); @@ -182,7 +182,7 @@ class TensorIntrinMatcher final : public StmtExprMutator { return expr; } - Expr VisitExpr_(const Variable* op) final { + Expr VisitExpr_(const VarNode* op) final { auto it = var_remap_.find(op); if (it != var_remap_.end()) { return it->second; @@ -301,7 +301,7 @@ class TensorIntrinMatcher final : public StmtExprMutator { // input data remap std::unordered_map in_remap_; // variable remap. - std::unordered_map var_remap_; + std::unordered_map var_remap_; // IterVar remap. std::unordered_map axis_remap_; }; @@ -415,7 +415,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, Call::make(DataType::Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop)); } // Check variable remap - std::unordered_map vmap; + std::unordered_map vmap; ir::ArgBinder binder(&vmap); CHECK_GE(self->reduce_axis.size(), intrin_compute->reduce_axis.size()) << "Tensorization fail: reduction axis size do not match"; diff --git a/src/pass/verify_gpu_code.cc b/src/pass/verify_gpu_code.cc index 08ec41334cb0..3be2821ad791 100644 --- a/src/pass/verify_gpu_code.cc +++ b/src/pass/verify_gpu_code.cc @@ -95,9 +95,9 @@ class GPUCodeVerifier : public StmtVisitor { if (op->attr_key == attr::storage_scope) { std::string op_value = op->value.as()->value; if (op_value == "local") { - visited_local_buffers_.insert(op->node.as()); + visited_local_buffers_.insert(op->node.as()); } else if (op_value == "shared") { - visited_shared_buffers_.insert(op->node.as()); + visited_shared_buffers_.insert(op->node.as()); } } else if (op->attr_key == attr::thread_extent) { VarExpr var = op->node.as()->var; @@ -140,8 +140,8 @@ class GPUCodeVerifier : public StmtVisitor { private: int nest_level_{0}; - std::unordered_set visited_local_buffers_; - std::unordered_set visited_shared_buffers_; + std::unordered_set visited_local_buffers_; + std::unordered_set visited_shared_buffers_; std::unordered_set visited_threads_; size_t thread_x_extent_, thread_y_extent_, thread_z_extent_; diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 248f06bbe508..b41d381dd827 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -196,7 +196,7 @@ class AlphaEqualHandler: } } using AttrsEqualHandler::VisitAttr_; - bool VisitAttr_(const Variable* lhs, const ObjectRef& other) final { + bool VisitAttr_(const tvm::VarNode* lhs, const ObjectRef& other) final { return LeafObjectEqual(GetRef(lhs), other); } diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index cf1e280d6b9e..d179d7ebb849 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -125,8 +125,8 @@ class RelayHashHandler: } using AttrsHashHandler::VisitAttr_; - size_t VisitAttr_(const Variable* var) final { - size_t hash = std::hash()(Variable::_type_key); + size_t VisitAttr_(const tvm::VarNode* var) final { + size_t hash = std::hash()(VarNode::_type_key); auto it = hash_map_.find(GetRef(var)); if (it != hash_map_.end()) { return it->second; diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 221f2c1b2cad..b800801a7f5d 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -182,7 +182,7 @@ class TypeSolver::Unifier : public TypeFunctor { return Any::make(); } - auto left_index0 = ulhs.as(); + auto left_index0 = ulhs.as(); auto right_index0 = urhs.as(); if (left_index0 && right_index0) { solver_->shape_uf_.Set(ulhs, urhs); @@ -190,7 +190,7 @@ class TypeSolver::Unifier : public TypeFunctor { } auto left_index1 = ulhs.as(); - auto right_index1 = urhs.as(); + auto right_index1 = urhs.as(); if (left_index1 && right_index1) { solver_->shape_uf_.Set(urhs, ulhs); return ulhs; diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index 7cf5cff0aff7..ce2397b1d4f7 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -137,7 +137,7 @@ void InferRootBound(const Stage& stage, Array stage_attach = ctx.attach_path.at(stage->op); // The parent set. for (const Operation& op : consumers) { - std::unordered_map relax_set; + std::unordered_map relax_set; std::unordered_map up_state; bool found_attach = false; CHECK(ctx.op2stage_.count(op.get())); @@ -188,7 +188,7 @@ void InferRootBound(const Stage& stage, // Get the domain of the consumer PassUpDomain(op_stage, *rmap, &up_state); // Relax if needed. - std::unordered_map dom_map; + std::unordered_map dom_map; arith::Analyzer analyzer; for (auto iv : op->root_iter_vars()) { Range r; diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc index a5ed43601024..56370241691d 100644 --- a/src/schedule/graph.cc +++ b/src/schedule/graph.cc @@ -269,7 +269,7 @@ ReachGraph GetReachGraph(const Array& ops) { for (size_t i = 0; i < call->args.size(); ++i) { TensorDimKey dkey(call, static_cast(i)); auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) { - const Variable *v = node.as(); + const VarNode *v = node.as(); auto it = vmap.find(v); if (it != vmap.end()) { reach[it->second].push_back(dkey); diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index f917e7fe6387..d08b4bec5849 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -501,7 +501,7 @@ std::vector MakeBoundCheck( PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer); std::vector preds; - std::unordered_map iset_dmap; + std::unordered_map iset_dmap; // setup domain map for set analysis for (const auto& kv : dom_map) { diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 9aef563fbefc..257ff320d623 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -45,9 +45,9 @@ size_t FindNodeRef(ArrayNode* array_node, const T& v) { class VarReplacer : public ir::StmtExprMutator { public: explicit VarReplacer( - const std::unordered_map& vsub) + const std::unordered_map& vsub) : vsub_(vsub) {} - Expr VisitExpr_(const Variable* op) final { + Expr VisitExpr_(const VarNode* op) final { auto it = vsub_.find(op); if (it != vsub_.end()) return it->second; return GetRef(op); @@ -88,7 +88,7 @@ class VarReplacer : public ir::StmtExprMutator { } private: - const std::unordered_map& vsub_; + const std::unordered_map& vsub_; }; Expr InjectPredicate(const Array& predicates, @@ -193,8 +193,8 @@ void PrepareAxisMapping(Stage orig_stage, std::unordered_set* p_red_axis, Array* p_new_axis, std::unordered_map* p_dom_map, - std::unordered_map* p_vsub, - std::unordered_map* p_vsub2newvar, + std::unordered_map* p_vsub, + std::unordered_map* p_vsub2newvar, std::vector* p_predicates) { auto& red_axis = *p_red_axis; auto& new_axis = *p_new_axis; @@ -305,8 +305,8 @@ Array CacheWriteWithReLayout(Schedule sch, Array new_axis; std::unordered_map dom_map; - std::unordered_map vsub; - std::unordered_map vsub2newvar; + std::unordered_map vsub; + std::unordered_map vsub2newvar; std::vector predicates; PrepareAxisMapping(orig_stage, compute, @@ -386,8 +386,8 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, Array new_axis; std::unordered_map dom_map; - std::unordered_map vsub; - std::unordered_map vsub2newvar; + std::unordered_map vsub; + std::unordered_map vsub2newvar; std::vector predicates; PrepareAxisMapping(orig_stage, tensor_op, @@ -763,7 +763,7 @@ Array Schedule::rfactor(const Tensor& tensor, predicates.push_back(reduce->condition); Expr predicate = likely(arith::ComputeReduce(predicates, Expr())); - std::unordered_map vsub; + std::unordered_map vsub; for (IterVar iv : compute_op->reduce_axis) { if (!touch_map.count(iv)) { diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index 2d494522b211..75f675d7663c 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -292,7 +292,7 @@ class SchedulePostProc : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } - Expr VisitExpr_(const Variable* op) final { + Expr VisitExpr_(const VarNode* op) final { auto it = var_value_.find(op); if (it != var_value_.end()) { return it->second; @@ -345,7 +345,7 @@ class SchedulePostProc : public StmtExprMutator { // The thread extent scope. std::unordered_map thread_extent_scope_; // The scan value - std::unordered_map var_value_; + std::unordered_map var_value_; // buffer replacement std::unordered_map replace_buffer_; // buffere realization to be replaced diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index a37f6f97d920..eaaec172f971 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -31,7 +31,7 @@ TEST(IRF, Basic) { auto z = x + 1; NodeFunctor f; - f.set_dispatch([](const ObjectRef& n, int b) { + f.set_dispatch([](const ObjectRef& n, int b) { return b; }); f.set_dispatch([](const ObjectRef& n, int b) { @@ -48,7 +48,7 @@ TEST(IRF, CountVar) { auto z = x + 1 + y + y; ir::PostOrderVisit(z, [&n_var](const ObjectRef& n) { - if (n.as()) ++n_var; + if (n.as()) ++n_var; }); CHECK_EQ(n_var, 2); } @@ -63,7 +63,7 @@ TEST(IRF, ExprTransform) { class MyExprFunctor : public ir::ExprFunctor { public: - int VisitExpr_(const Variable* op, int b) final { + int VisitExpr_(const VarNode* op, int b) final { return b; } int VisitExpr_(const IntImm* op, int b) final { diff --git a/topi/include/topi/detail/broadcast.h b/topi/include/topi/detail/broadcast.h index 4fdd18626498..8c5068a2f35d 100644 --- a/topi/include/topi/detail/broadcast.h +++ b/topi/include/topi/detail/broadcast.h @@ -52,8 +52,8 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, int i; for (i = 1; i <= std::min(s1_size, s2_size); ++i) { // TODO(@icemelon9): Need to revisit this part - const Variable* var1 = shape1[s1_size - i].as(); - const Variable* var2 = shape2[s2_size - i].as(); + const VarNode* var1 = shape1[s1_size - i].as(); + const VarNode* var2 = shape2[s2_size - i].as(); bh.all_vars.push_front(tvm::Var()); if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) { bh.common_shape.push_front(shape1[s1_size - i]); From 1fe787f330faa931d2b2cd879a48b82d95d8c514 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 7 Jan 2020 14:52:33 -0800 Subject: [PATCH 02/11] [REFACTOR][IR] Add/Sub/Mul/Div -> AddNode/SubNode etc. --- include/tvm/ir.h | 14 +++---- include/tvm/ir_functor_ext.h | 48 ++++++++++++------------ src/api/api_ir.cc | 12 +++--- src/arithmetic/bound_deducer.cc | 6 +-- src/arithmetic/canonical_simplify.cc | 36 +++++++++--------- src/arithmetic/compute_expr.h | 10 ++--- src/arithmetic/const_fold.h | 10 ++--- src/arithmetic/const_int_bound.cc | 16 ++++---- src/arithmetic/detect_linear_equation.cc | 6 +-- src/arithmetic/int_operator.h | 8 ++-- src/arithmetic/int_set.cc | 24 ++++++------ src/arithmetic/modular_set.cc | 10 ++--- src/arithmetic/pattern_match.h | 20 +++++----- src/arithmetic/rewrite_simplify.cc | 34 ++++++++--------- src/arithmetic/rewrite_simplify.h | 12 +++--- src/autotvm/touch_extractor.cc | 2 +- src/autotvm/touch_extractor.h | 10 ++--- src/codegen/codegen_c.cc | 12 +++--- src/codegen/codegen_c.h | 12 +++--- src/codegen/llvm/codegen_llvm.cc | 10 ++--- src/codegen/llvm/codegen_llvm.h | 12 +++--- src/codegen/llvm/codegen_x86_64.cc | 4 +- src/codegen/spirv/codegen_spirv.cc | 12 +++--- src/codegen/spirv/codegen_spirv.h | 12 +++--- src/codegen/stackvm/codegen_stackvm.cc | 12 +++--- src/codegen/stackvm/codegen_stackvm.h | 12 +++--- src/contrib/hybrid/codegen_hybrid.cc | 12 +++--- src/contrib/hybrid/codegen_hybrid.h | 12 +++--- src/lang/attr_functor.h | 48 ++++++++++++------------ src/lang/attrs.cc | 28 +++++++------- src/lang/buffer.cc | 18 ++++----- src/lang/expr_operator.cc | 32 ++++++++-------- src/lang/ir.cc | 40 ++++++++++---------- src/pass/bound_checker.cc | 16 ++++---- src/pass/inject_copy_intrin.cc | 2 +- src/pass/inject_double_buffer.cc | 2 +- src/pass/inject_virtual_thread.cc | 2 +- src/pass/ir_deep_compare.cc | 14 +++---- src/pass/ir_functor.cc | 26 ++++++------- src/pass/lower_custom_datatypes.cc | 36 +++++++++--------- src/pass/lower_intrin.cc | 16 ++++---- src/pass/lower_tvm_builtin.cc | 4 +- src/pass/make_api.cc | 2 +- src/pass/rewrite_unsafe_select.cc | 12 +++--- src/pass/storage_flatten.cc | 6 +-- src/pass/storage_rewrite.cc | 4 +- src/pass/tensor_core.cc | 28 +++++++------- src/pass/vectorize_loop.cc | 16 ++++---- src/relay/op/nn/upsampling.cc | 10 ++--- tests/cpp/ir_functor_test.cc | 8 ++-- tests/cpp/ir_simplify_test.cc | 2 +- tests/cpp/pattern_match_test.cc | 8 ++-- 52 files changed, 390 insertions(+), 390 deletions(-) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 5dd05beb7302..580c9b0b2455 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -94,7 +94,7 @@ class StringImm : public ExprNode { * \brief Cast value from one data type to another. * \note The lanes of value should keep fixed. */ -class Cast : public ExprNode { +class CastNode : public ExprNode { public: /*! \brief Original data type. */ Expr value; @@ -107,7 +107,7 @@ class Cast : public ExprNode { TVM_DLL static Expr make(DataType t, Expr v); static constexpr const char* _type_key = "Cast"; - TVM_DECLARE_FINAL_OBJECT_INFO(Cast, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, ExprNode); }; /*! @@ -143,19 +143,19 @@ class BinaryOpNode : public ExprNode { }; /*! \brief a + b */ -class Add : public BinaryOpNode { +class AddNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Add"; }; /*! \brief a - b */ -class Sub : public BinaryOpNode { +class SubNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Sub"; }; /*! \brief a * b */ -class Mul : public BinaryOpNode { +class MulNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Mul"; }; @@ -164,7 +164,7 @@ class Mul : public BinaryOpNode { * \brief a / b in the C semnatics. * \note For integer division, C standard uses trunc div. */ -class Div : public BinaryOpNode
{ +class DivNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Div"; }; @@ -173,7 +173,7 @@ class Div : public BinaryOpNode
{ * \brief a % b in the C semnatics. * \note For integer division, C standard uses trunc div. */ -class Mod : public BinaryOpNode { +class ModNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Mod"; }; diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 6cc6d702c7cd..19fc345290ad 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -136,11 +136,11 @@ class ExprFunctor { virtual R VisitExpr_(const Load* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Let* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Call* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Add* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Sub* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const AddNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const SubNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const MulNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloorDiv* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloorMod* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -154,7 +154,7 @@ class ExprFunctor { virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const CastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -178,11 +178,11 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(Load); IR_EXPR_FUNCTOR_DISPATCH(Let); IR_EXPR_FUNCTOR_DISPATCH(Call); - IR_EXPR_FUNCTOR_DISPATCH(Add); - IR_EXPR_FUNCTOR_DISPATCH(Sub); - IR_EXPR_FUNCTOR_DISPATCH(Mul); - IR_EXPR_FUNCTOR_DISPATCH(Div); - IR_EXPR_FUNCTOR_DISPATCH(Mod); + IR_EXPR_FUNCTOR_DISPATCH(AddNode); + IR_EXPR_FUNCTOR_DISPATCH(SubNode); + IR_EXPR_FUNCTOR_DISPATCH(MulNode); + IR_EXPR_FUNCTOR_DISPATCH(DivNode); + IR_EXPR_FUNCTOR_DISPATCH(ModNode); IR_EXPR_FUNCTOR_DISPATCH(FloorDiv); IR_EXPR_FUNCTOR_DISPATCH(FloorMod); IR_EXPR_FUNCTOR_DISPATCH(Min); @@ -196,7 +196,7 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(And); IR_EXPR_FUNCTOR_DISPATCH(Or); IR_EXPR_FUNCTOR_DISPATCH(Reduce); - IR_EXPR_FUNCTOR_DISPATCH(Cast); + IR_EXPR_FUNCTOR_DISPATCH(CastNode); IR_EXPR_FUNCTOR_DISPATCH(Not); IR_EXPR_FUNCTOR_DISPATCH(Select); IR_EXPR_FUNCTOR_DISPATCH(Ramp); @@ -302,11 +302,11 @@ class TVM_DLL ExprVisitor : void VisitExpr_(const Load* op) override; void VisitExpr_(const Let* op) override; void VisitExpr_(const Call* op) override; - void VisitExpr_(const Add* op) override; - void VisitExpr_(const Sub* op) override; - void VisitExpr_(const Mul* op) override; - void VisitExpr_(const Div* op) override; - void VisitExpr_(const Mod* op) override; + void VisitExpr_(const AddNode* op) override; + void VisitExpr_(const SubNode* op) override; + void VisitExpr_(const MulNode* op) override; + void VisitExpr_(const DivNode* op) override; + void VisitExpr_(const ModNode* op) override; void VisitExpr_(const FloorDiv* op) override; void VisitExpr_(const FloorMod* op) override; void VisitExpr_(const Min* op) override; @@ -320,7 +320,7 @@ class TVM_DLL ExprVisitor : void VisitExpr_(const And* op) override; void VisitExpr_(const Or* op) override; void VisitExpr_(const Reduce* op) override; - void VisitExpr_(const Cast* op) override; + void VisitExpr_(const CastNode* op) override; void VisitExpr_(const Not* op) override; void VisitExpr_(const Select* op) override; void VisitExpr_(const Ramp* op) override; @@ -347,11 +347,11 @@ class TVM_DLL ExprMutator : Expr VisitExpr_(const Load* op) override; Expr VisitExpr_(const Let* op) override; Expr VisitExpr_(const Call* op) override; - Expr VisitExpr_(const Add* op) override; - Expr VisitExpr_(const Sub* op) override; - Expr VisitExpr_(const Mul* op) override; - Expr VisitExpr_(const Div* op) override; - Expr VisitExpr_(const Mod* op) override; + Expr VisitExpr_(const AddNode* op) override; + Expr VisitExpr_(const SubNode* op) override; + Expr VisitExpr_(const MulNode* op) override; + Expr VisitExpr_(const DivNode* op) override; + Expr VisitExpr_(const ModNode* op) override; Expr VisitExpr_(const FloorDiv* op) override; Expr VisitExpr_(const FloorMod* op) override; Expr VisitExpr_(const Min* op) override; @@ -365,7 +365,7 @@ class TVM_DLL ExprMutator : Expr VisitExpr_(const And* op) override; Expr VisitExpr_(const Or* op) override; Expr VisitExpr_(const Reduce* op) override; - Expr VisitExpr_(const Cast* op) override; + Expr VisitExpr_(const CastNode* op) override; Expr VisitExpr_(const Not* op) override; Expr VisitExpr_(const Select* op) override; Expr VisitExpr_(const Ramp* op) override; diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 034405f1a7f0..6d271f6b77ca 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -134,11 +134,11 @@ REGISTER_MAKE(UIntImm); REGISTER_MAKE(FloatImm); REGISTER_MAKE(StringImm); -REGISTER_MAKE(Add); -REGISTER_MAKE(Sub); -REGISTER_MAKE(Mul); -REGISTER_MAKE(Div); -REGISTER_MAKE(Mod); +REGISTER_MAKE(AddNode); +REGISTER_MAKE(SubNode); +REGISTER_MAKE(MulNode); +REGISTER_MAKE(DivNode); +REGISTER_MAKE(ModNode); REGISTER_MAKE(FloorDiv); REGISTER_MAKE(FloorMod); REGISTER_MAKE(Min); @@ -155,7 +155,7 @@ REGISTER_MAKE(Or); REGISTER_MAKE(Not); REGISTER_MAKE(Select); REGISTER_MAKE(Ramp); -REGISTER_MAKE(Cast); +REGISTER_MAKE(CastNode); REGISTER_MAKE(Broadcast); REGISTER_MAKE(Shuffle); REGISTER_MAKE(Let); diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index b9c423c0ad47..65574ef6327c 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -110,13 +110,13 @@ class BoundDeducer: public ExprVisitor { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void VisitExpr_(const Add* op) final { + void VisitExpr_(const AddNode* op) final { bool left = op->a.get() == path_[iter_]; result_ -= left ? op->b : op->a; this->VisitExpr(left ? op->a : op->b); } - void VisitExpr_(const Sub* op) final { + void VisitExpr_(const SubNode* op) final { bool left = op->a.get() == path_[iter_]; if (left) { result_ += op->b; @@ -128,7 +128,7 @@ class BoundDeducer: public ExprVisitor { this->VisitExpr(left ? op->a : op->b); } - void VisitExpr_(const Mul* op) final { + void VisitExpr_(const MulNode* op) final { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; Expr target_var = left ? op->a : op->b; diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index d05ee2dd9a30..1ede2453198d 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -450,11 +450,11 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { } using Rewriter::VisitExpr_; - Expr VisitExpr_(const Add* op) final; - Expr VisitExpr_(const Sub* op) final; - Expr VisitExpr_(const Mul* op) final; - Expr VisitExpr_(const Div* op) final; - Expr VisitExpr_(const Mod* op) final; + Expr VisitExpr_(const AddNode* op) final; + Expr VisitExpr_(const SubNode* op) final; + Expr VisitExpr_(const MulNode* op) final; + Expr VisitExpr_(const DivNode* op) final; + Expr VisitExpr_(const ModNode* op) final; Expr VisitExpr_(const FloorDiv* op) final; Expr VisitExpr_(const FloorMod* op) final; Expr VisitExpr_(const Reduce* op) final; @@ -566,7 +566,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { }; Expr CanonicalSimplifier::Impl:: -VisitExpr_(const Add* op) { +VisitExpr_(const AddNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -575,7 +575,7 @@ VisitExpr_(const Add* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // canonical form simplification. @@ -592,7 +592,7 @@ VisitExpr_(const Add* op) { } Expr CanonicalSimplifier::Impl:: -VisitExpr_(const Sub* op) { +VisitExpr_(const SubNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -601,7 +601,7 @@ VisitExpr_(const Sub* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // canonical form simplification. @@ -619,7 +619,7 @@ VisitExpr_(const Sub* op) { Expr CanonicalSimplifier::Impl:: -VisitExpr_(const Mul* op) { +VisitExpr_(const MulNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -628,7 +628,7 @@ VisitExpr_(const Mul* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // x * c @@ -653,7 +653,7 @@ VisitExpr_(const Mul* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return Mul::make(a, b); + return MulNode::make(a, b); } } @@ -726,7 +726,7 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { } Expr CanonicalSimplifier::Impl:: -VisitExpr_(const Div* op) { +VisitExpr_(const DivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -735,7 +735,7 @@ VisitExpr_(const Div* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold
(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; // x / c1 @@ -782,7 +782,7 @@ VisitExpr_(const Div* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return Div::make(a, b); + return DivNode::make(a, b); } } @@ -893,7 +893,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { } Expr CanonicalSimplifier::Impl:: -VisitExpr_(const Mod* op) { +VisitExpr_(const ModNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -902,7 +902,7 @@ VisitExpr_(const Mod* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; @@ -958,7 +958,7 @@ VisitExpr_(const Mod* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return Mod::make(a, b); + return ModNode::make(a, b); } } diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h index 806587ab75aa..36571078f0b1 100644 --- a/src/arithmetic/compute_expr.h +++ b/src/arithmetic/compute_expr.h @@ -77,27 +77,27 @@ inline bool GetConstInt(Expr e, int* out) { } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return a + b; } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return a - b; } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return a * b; } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return truncdiv(a, b); } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return truncmod(a, b); } diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 8b4ea2fa8133..dbbbdcfbecda 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -100,7 +100,7 @@ inline bool IsIndexType(const DataType& type) { // specialization of constant folders. template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, pa->value + pb->value); @@ -114,7 +114,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, pa->value - pb->value); @@ -126,7 +126,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, pa->value * pb->value); @@ -152,7 +152,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -181,7 +181,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index 9a25c47f37a0..9a67a1b759c2 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -140,7 +140,7 @@ class ConstIntBoundAnalyzer::Impl : return res; } - Entry VisitExpr_(const Cast* op) final { + Entry VisitExpr_(const CastNode* op) final { Entry a = VisitExpr(op->value); Entry b = Everything(op->dtype); return Intersect(a, b); @@ -158,7 +158,7 @@ class ConstIntBoundAnalyzer::Impl : } } - Entry VisitExpr_(const Add* op) final { + Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); Entry ret; @@ -167,7 +167,7 @@ class ConstIntBoundAnalyzer::Impl : return ret; } - Entry VisitExpr_(const Sub* op) final { + Entry VisitExpr_(const SubNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); Entry ret; @@ -176,13 +176,13 @@ class ConstIntBoundAnalyzer::Impl : return ret; } - Entry VisitExpr_(const Mul* op) final { + Entry VisitExpr_(const MulNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); return BinaryOpBoundry(a, b, InfAwareMul); } - Entry VisitExpr_(const Div* op) final { + Entry VisitExpr_(const DivNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); CHECK(!b.is_const(0)) << "divide by zero"; @@ -192,7 +192,7 @@ class ConstIntBoundAnalyzer::Impl : return BinaryOpBoundry(a, b, InfAwareDiv); } - Entry VisitExpr_(const Mod* op) final { + Entry VisitExpr_(const ModNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); if (b.min_value > 0) { @@ -375,7 +375,7 @@ class ConstIntBoundAnalyzer::Impl : return kNegInf; } if (y == kPosInf || y == kNegInf) return y; - if (WillOverflow(x, y, kNegInf, kPosInf)) { + if (WillOverflow(x, y, kNegInf, kPosInf)) { if (x > 0) return kPosInf; return kNegInf; } @@ -388,7 +388,7 @@ class ConstIntBoundAnalyzer::Impl : * \return the result. */ static int64_t InfAwareMul(int64_t x, int64_t y) { - if (!WillOverflow(x, y, kNegInf, kPosInf)) return x * y; + if (!WillOverflow(x, y, kNegInf, kPosInf)) return x * y; if ((x > 0 && y > 0) || (x < 0 && y < 0)) return kPosInf; return kNegInf; } diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 5c964d55bc9c..b9a9a1ecb77e 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -60,7 +60,7 @@ class LinearEqDetector return true; } - LinearEqEntry VisitExpr_(const Add* op, const Expr& e) final { + LinearEqEntry VisitExpr_(const AddNode* op, const Expr& e) final { if (fail_) return LinearEqEntry(); LinearEqEntry a = VisitExpr(op->a, op->a); LinearEqEntry b = VisitExpr(op->b, op->b); @@ -70,7 +70,7 @@ class LinearEqDetector return ret; } - LinearEqEntry VisitExpr_(const Sub* op, const Expr& e) final { + LinearEqEntry VisitExpr_(const SubNode* op, const Expr& e) final { if (fail_) return LinearEqEntry(); LinearEqEntry a = VisitExpr(op->a, op->a); LinearEqEntry b = VisitExpr(op->b, op->b); @@ -80,7 +80,7 @@ class LinearEqDetector return ret; } - LinearEqEntry VisitExpr_(const Mul* op, const Expr& e) final { + LinearEqEntry VisitExpr_(const MulNode* op, const Expr& e) final { if (fail_) return LinearEqEntry(); LinearEqEntry a = VisitExpr(op->a, op->a); LinearEqEntry b = VisitExpr(op->b, op->b); diff --git a/src/arithmetic/int_operator.h b/src/arithmetic/int_operator.h index e3adf1fa269f..9379eeb86f17 100644 --- a/src/arithmetic/int_operator.h +++ b/src/arithmetic/int_operator.h @@ -47,7 +47,7 @@ inline bool WillOverflow(int64_t x, } template<> -inline bool WillOverflow(int64_t x, +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { @@ -57,7 +57,7 @@ inline bool WillOverflow(int64_t x, } template<> -inline bool WillOverflow(int64_t x, +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { @@ -67,7 +67,7 @@ inline bool WillOverflow(int64_t x, } template<> -inline bool WillOverflow(int64_t x, +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { @@ -84,7 +84,7 @@ inline bool WillOverflow(int64_t x, } template<> -inline bool WillOverflow(int64_t x, +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 042d85e28751..86cb1bf622c3 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -118,7 +118,7 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analyer, +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -136,7 +136,7 @@ inline IntervalSet Combine(Analyzer* analyer, } template<> -inline IntervalSet Combine(Analyzer* analyer, +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -155,7 +155,7 @@ inline IntervalSet Combine(Analyzer* analyer, template<> -inline IntervalSet Combine(Analyzer* analyzer, +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -190,7 +190,7 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analyzer, +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -225,7 +225,7 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analyzer, +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -405,23 +405,23 @@ class IntervalSetEvaluator : } } - IntervalSet VisitExpr_(const Add* op) final { + IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Sub* op) final { + IntervalSet VisitExpr_(const SubNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Mul* op) final { + IntervalSet VisitExpr_(const MulNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Div* op) final { + IntervalSet VisitExpr_(const DivNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Mod* op) final { + IntervalSet VisitExpr_(const ModNode* op) final { return VisitBinaryExpr_(op); } @@ -481,12 +481,12 @@ class IntervalSetEvaluator : DataType t = op->base.dtype(); int64_t vstride = stride.Eval()->value; if (vstride> 0) { - return Combine( + return Combine( analyzer_, base, IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); } else { - return Combine( + return Combine( analyzer_, base, IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index f40492325ca4..37fa30debc2d 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -124,7 +124,7 @@ class ModularSetAnalyzer::Impl : return Everything(); } - Entry VisitExpr_(const Cast* op) final { + Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } @@ -140,21 +140,21 @@ class ModularSetAnalyzer::Impl : } } - Entry VisitExpr_(const Add* op) final { + Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); return Entry(coeff, a.base + b.base); } - Entry VisitExpr_(const Sub* op) final { + Entry VisitExpr_(const SubNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); return Entry(coeff, a.base - b.base); } - Entry VisitExpr_(const Mul* op) final { + Entry VisitExpr_(const MulNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); // Simplification rule, x, y, z are in Z @@ -188,7 +188,7 @@ class ModularSetAnalyzer::Impl : return Everything(); } - Entry VisitExpr_(const Div* op) final { + Entry VisitExpr_(const DivNode* op) final { Entry b = VisitExpr(op->b); if (b.is_const()) { return DivByConst(op->a, b.base, false); diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index bff956473c87..7e20302ca8f8 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -325,18 +325,18 @@ class PConstWithTypeLike : // raise ambiguity error for operator overload of / and % -TVM_PATTERN_BINARY_OP_EX(operator/, ir::Div, DivAmbiguityError(a)); -TVM_PATTERN_BINARY_OP_EX(operator%, ir::Mod, DivAmbiguityError(a)); +TVM_PATTERN_BINARY_OP_EX(operator/, ir::DivNode, DivAmbiguityError(a)); +TVM_PATTERN_BINARY_OP_EX(operator%, ir::ModNode, DivAmbiguityError(a)); // arithmetic expressions -TVM_PATTERN_BINARY_OP(operator+, ir::Add); -TVM_PATTERN_BINARY_OP(operator-, ir::Sub); -TVM_PATTERN_BINARY_OP(operator*, ir::Mul); +TVM_PATTERN_BINARY_OP(operator+, ir::AddNode); +TVM_PATTERN_BINARY_OP(operator-, ir::SubNode); +TVM_PATTERN_BINARY_OP(operator*, ir::MulNode); TVM_PATTERN_BINARY_OP(min, ir::Min); TVM_PATTERN_BINARY_OP(max, ir::Max); -TVM_PATTERN_BINARY_OP(div, ir::Div); -TVM_PATTERN_BINARY_OP(truncdiv, ir::Div); -TVM_PATTERN_BINARY_OP(truncmod, ir::Mod); +TVM_PATTERN_BINARY_OP(div, ir::DivNode); +TVM_PATTERN_BINARY_OP(truncdiv, ir::DivNode); +TVM_PATTERN_BINARY_OP(truncmod, ir::ModNode); TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv); TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod); @@ -473,7 +473,7 @@ class PCastExpr : } bool Match_(const ObjectRef& node) const { - if (const ir::Cast* ptr = node.as()) { + if (const ir::CastNode* ptr = node.as()) { if (!dtype_.Match_(ptr->dtype)) return false; if (!value_.Match_(ptr->value)) return false; return true; @@ -483,7 +483,7 @@ class PCastExpr : } Expr Eval() const { - return ir::Cast::make(dtype_.Eval(), value_.Eval()); + return ir::CastNode::make(dtype_.Eval(), value_.Eval()); } private: diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 2c1fa5dc69f9..9e81645e5a02 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -116,10 +116,10 @@ Update(const Var& var, const Expr& info, bool override) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Add* op) { +VisitExpr_(const AddNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; @@ -231,10 +231,10 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const Expr& const } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Sub* op) { +VisitExpr_(const SubNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; @@ -430,10 +430,10 @@ VisitExpr_(const Sub* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Mul* op) { +VisitExpr_(const MulNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; @@ -469,10 +469,10 @@ VisitExpr_(const Mul* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Div* op) { +VisitExpr_(const DivNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as
(); - Expr const_res = TryConstFold
(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1; @@ -691,10 +691,10 @@ VisitExpr_(const Div* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Mod* op) { +VisitExpr_(const ModNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1738,9 +1738,9 @@ VisitExpr_(const VarNode* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Cast* op) { +VisitExpr_(const CastNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); + op = ret.as(); return cast(op->dtype, op->value); } diff --git a/src/arithmetic/rewrite_simplify.h b/src/arithmetic/rewrite_simplify.h index 4984bc524924..39f5cf4f9954 100644 --- a/src/arithmetic/rewrite_simplify.h +++ b/src/arithmetic/rewrite_simplify.h @@ -50,11 +50,11 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { : IRMutatorWithAnalyzer(parent) {} void Update(const Var& var, const Expr& info, bool override_info); - Expr VisitExpr_(const Add* op) override; - Expr VisitExpr_(const Sub* op) override; - Expr VisitExpr_(const Mul* op) override; - Expr VisitExpr_(const Div* op) override; - Expr VisitExpr_(const Mod* op) override; + Expr VisitExpr_(const AddNode* op) override; + Expr VisitExpr_(const SubNode* op) override; + Expr VisitExpr_(const MulNode* op) override; + Expr VisitExpr_(const DivNode* op) override; + Expr VisitExpr_(const ModNode* op) override; Expr VisitExpr_(const FloorDiv* op) override; Expr VisitExpr_(const FloorMod* op) override; Expr VisitExpr_(const Min* op) override; @@ -71,7 +71,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { Expr VisitExpr_(const Select* op) override; Expr VisitExpr_(const Call* op) override; Expr VisitExpr_(const VarNode* op) override; - Expr VisitExpr_(const Cast* op) override; + Expr VisitExpr_(const CastNode* op) override; Expr VisitExpr_(const Let* op) override; std::function EnterConstraint(const Expr& constraint); diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index c986ef7cbc2d..9f906220c5be 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -60,7 +60,7 @@ class IndexParser: public ExprVisitor { } } - void VisitExpr_(const Mul* op) final { + void VisitExpr_(const MulNode* op) final { if (op->a.as()) { if (const auto stride = op->b.as()) { next_stride_ = stride->value; diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h index 2bcf6b808a98..5265aad9df06 100644 --- a/src/autotvm/touch_extractor.h +++ b/src/autotvm/touch_extractor.h @@ -92,31 +92,31 @@ class TouchExtractor : public FeatureVisitor { } // arithmetic stats - void VisitExpr_(const Add* op) final { + void VisitExpr_(const AddNode* op) final { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; FeatureVisitor::VisitExpr_(op); } - void VisitExpr_(const Sub* op) final { + void VisitExpr_(const SubNode* op) final { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; FeatureVisitor::VisitExpr_(op); } - void VisitExpr_(const Mul* op) final { + void VisitExpr_(const MulNode* op) final { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].mul_ct++; FeatureVisitor::VisitExpr_(op); } - void VisitExpr_(const Div* op) final { + void VisitExpr_(const DivNode* op) final { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; FeatureVisitor::VisitExpr_(op); } - void VisitExpr_(const Mod* op) final { + void VisitExpr_(const ModNode* op) final { if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; FeatureVisitor::VisitExpr_(op); diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index bacfed02efaa..523a565a866d 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -457,7 +457,7 @@ inline void PrintBinaryIntrinsic(const Call* op, p->PrintVecBinaryOp(opstr, op->dtype, op->args[0], op->args[1], os); } } -void CodeGenC::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*) std::stringstream value; this->PrintExpr(op->value, value); os << CastFromTo(value.str(), op->value.dtype(), op->dtype); @@ -465,19 +465,19 @@ void CodeGenC::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*) void CodeGenC::VisitExpr_(const VarNode* op, std::ostream& os) { // NOLINT(*) os << GetVarID(op); } -void CodeGenC::VisitExpr_(const Add* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const AddNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "+", os, this); } -void CodeGenC::VisitExpr_(const Sub* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const SubNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "-", os, this); } -void CodeGenC::VisitExpr_(const Mul* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const MulNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "*", os, this); } -void CodeGenC::VisitExpr_(const Div* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "/", os, this); } -void CodeGenC::VisitExpr_(const Mod* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "%", os, this); } void CodeGenC::VisitExpr_(const Min* op, std::ostream& os) { // NOLINT(*) diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index b5555f25ace9..345e817e065f 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -106,11 +106,11 @@ class CodeGenC : void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*) @@ -121,7 +121,7 @@ class CodeGenC : void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*) diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index b0d86a9f66ce..6ea9d42eaa8d 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -797,7 +797,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) { return GetVarValue(op); } -llvm::Value* CodeGenLLVM::VisitExpr_(const Cast* op) { +llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) { return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value)); } llvm::Value* CodeGenLLVM::VisitExpr_(const IntImm* op) { @@ -836,8 +836,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImm* op) { return builder_->CreateF ## Op (a, b); \ } \ } \ - llvm::Value* CodeGenLLVM::VisitExpr_(const Op* op) { \ - return Create ## Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \ + llvm::Value* CodeGenLLVM::VisitExpr_(const Op ## Node* op) { \ + return Create ## Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \ } DEFINE_CODEGEN_BINARY_OP(Add); @@ -865,7 +865,7 @@ DEFINE_CODEGEN_CMP_OP(LE); DEFINE_CODEGEN_CMP_OP(GT); DEFINE_CODEGEN_CMP_OP(GE); -llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) { +llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->dtype.is_int()) { @@ -878,7 +878,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Div* op) { } } -llvm::Value* CodeGenLLVM::VisitExpr_(const Mod* op) { +llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->dtype.is_int()) { diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 076ffb2af588..f1f48df78315 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -104,16 +104,16 @@ class CodeGenLLVM : } // override codegen llvm::Value* VisitExpr_(const Variable* op) override; - llvm::Value* VisitExpr_(const Cast* op) override; + llvm::Value* VisitExpr_(const CastNode* op) override; llvm::Value* VisitExpr_(const IntImm* op) override; llvm::Value* VisitExpr_(const UIntImm* op) override; llvm::Value* VisitExpr_(const FloatImm* op) override; llvm::Value* VisitExpr_(const StringImm* op) override; - llvm::Value* VisitExpr_(const Add* op) override; - llvm::Value* VisitExpr_(const Sub* op) override; - llvm::Value* VisitExpr_(const Mul* op) override; - llvm::Value* VisitExpr_(const Div* op) override; - llvm::Value* VisitExpr_(const Mod* op) override; + llvm::Value* VisitExpr_(const AddNode* op) override; + llvm::Value* VisitExpr_(const SubNode* op) override; + llvm::Value* VisitExpr_(const MulNode* op) override; + llvm::Value* VisitExpr_(const DivNode* op) override; + llvm::Value* VisitExpr_(const ModNode* op) override; llvm::Value* VisitExpr_(const Min* op) override; llvm::Value* VisitExpr_(const Max* op) override; llvm::Value* VisitExpr_(const LT* op) override; diff --git a/src/codegen/llvm/codegen_x86_64.cc b/src/codegen/llvm/codegen_x86_64.cc index d6138830bfb4..27b321131771 100644 --- a/src/codegen/llvm/codegen_x86_64.cc +++ b/src/codegen/llvm/codegen_x86_64.cc @@ -65,14 +65,14 @@ bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature) class CodeGenX86_64 final : public CodeGenCPU { public: - llvm::Value* VisitExpr_(const Cast* op) override; + llvm::Value* VisitExpr_(const CastNode* op) override; private: llvm::Value* CallVectorIntrin(llvm::Intrinsic::ID id, size_t intrin_lanes, llvm::Type* result_ty, const std::vector& args); }; -llvm::Value* CodeGenX86_64::VisitExpr_(const Cast* op) { +llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { // LLVM does not automatically generate the correct instruction sequences for // half -> float conversion (i.e. using AVX2/AVX-512 vectorized variants of // vcvtph2ps), so we explicitly generate them ourselves. diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index 72709ebf9def..cfc175ee0554 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -149,27 +149,27 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const StringImm* op) { return spirv::Value(); } -spirv::Value CodeGenSPIRV::VisitExpr_(const Cast* op) { +spirv::Value CodeGenSPIRV::VisitExpr_(const CastNode* op) { return builder_->Cast(builder_->GetSType(op->dtype), MakeValue(op->value)); } -spirv::Value CodeGenSPIRV::VisitExpr_(const Add* op) { +spirv::Value CodeGenSPIRV::VisitExpr_(const AddNode* op) { return builder_->Add(MakeValue(op->a), MakeValue(op->b)); } -spirv::Value CodeGenSPIRV::VisitExpr_(const Sub* op) { +spirv::Value CodeGenSPIRV::VisitExpr_(const SubNode* op) { return builder_->Sub(MakeValue(op->a), MakeValue(op->b)); } -spirv::Value CodeGenSPIRV::VisitExpr_(const Mul* op) { +spirv::Value CodeGenSPIRV::VisitExpr_(const MulNode* op) { return builder_->Mul(MakeValue(op->a), MakeValue(op->b)); } -spirv::Value CodeGenSPIRV::VisitExpr_(const Div* op) { +spirv::Value CodeGenSPIRV::VisitExpr_(const DivNode* op) { return builder_->Div(MakeValue(op->a), MakeValue(op->b)); } -spirv::Value CodeGenSPIRV::VisitExpr_(const Mod* op) { +spirv::Value CodeGenSPIRV::VisitExpr_(const ModNode* op) { return builder_->Mod(MakeValue(op->a), MakeValue(op->b)); } diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index c6833057418d..6839ed2a4635 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -63,16 +63,16 @@ class CodeGenSPIRV: } // override codegen spirv::Value VisitExpr_(const VarNode* op) override; - spirv::Value VisitExpr_(const Cast* op) override; + spirv::Value VisitExpr_(const CastNode* op) override; spirv::Value VisitExpr_(const IntImm* op) override; spirv::Value VisitExpr_(const UIntImm* op) override; spirv::Value VisitExpr_(const FloatImm* op) override; spirv::Value VisitExpr_(const StringImm* op) override; - spirv::Value VisitExpr_(const Add* op) override; - spirv::Value VisitExpr_(const Sub* op) override; - spirv::Value VisitExpr_(const Mul* op) override; - spirv::Value VisitExpr_(const Div* op) override; - spirv::Value VisitExpr_(const Mod* op) override; + spirv::Value VisitExpr_(const AddNode* op) override; + spirv::Value VisitExpr_(const SubNode* op) override; + spirv::Value VisitExpr_(const MulNode* op) override; + spirv::Value VisitExpr_(const DivNode* op) override; + spirv::Value VisitExpr_(const ModNode* op) override; spirv::Value VisitExpr_(const Min* op) override; spirv::Value VisitExpr_(const Max* op) override; spirv::Value VisitExpr_(const LT* op) override; diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index ce0f45b6ec10..21f7df811fea 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -295,28 +295,28 @@ void CodeGenStackVM::VisitExpr_(const VarNode* op) { this->PushOp(StackVM::LOAD_HEAP, vid); } -void CodeGenStackVM::VisitExpr_(const Cast* op) { +void CodeGenStackVM::VisitExpr_(const CastNode* op) { this->Push(op->value); PushCast(op->dtype, op->value.dtype()); } -void CodeGenStackVM::VisitExpr_(const Add* op) { +void CodeGenStackVM::VisitExpr_(const AddNode* op) { PushBinary(StackVM::ADD_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Sub* op) { +void CodeGenStackVM::VisitExpr_(const SubNode* op) { PushBinary(StackVM::SUB_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Mul* op) { +void CodeGenStackVM::VisitExpr_(const MulNode* op) { PushBinary(StackVM::MUL_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Div* op) { +void CodeGenStackVM::VisitExpr_(const DivNode* op) { PushBinary(StackVM::DIV_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Mod* op) { +void CodeGenStackVM::VisitExpr_(const ModNode* op) { PushBinary(StackVM::MOD_I64, op->a, op->b); } diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index 36287f783c58..283a31dfd754 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -115,11 +115,11 @@ class CodeGenStackVM void VisitExpr_(const Load* op) final; void VisitExpr_(const Let* op) final; void VisitExpr_(const Call* op) final; - void VisitExpr_(const Add* op) final; - void VisitExpr_(const Sub* op) final; - void VisitExpr_(const Mul* op) final; - void VisitExpr_(const Div* op) final; - void VisitExpr_(const Mod* op) final; + void VisitExpr_(const AddNode* op) final; + void VisitExpr_(const SubNode* op) final; + void VisitExpr_(const MulNode* op) final; + void VisitExpr_(const DivNode* op) final; + void VisitExpr_(const ModNode* op) final; void VisitExpr_(const Min* op) final; void VisitExpr_(const Max* op) final; void VisitExpr_(const EQ* op) final; @@ -130,7 +130,7 @@ class CodeGenStackVM void VisitExpr_(const GE* op) final; void VisitExpr_(const And* op) final; void VisitExpr_(const Or* op) final; - void VisitExpr_(const Cast* op) final; + void VisitExpr_(const CastNode* op) final; void VisitExpr_(const Not* op) final; void VisitExpr_(const Select* op) final; void VisitExpr_(const Ramp* op) final; diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index ea2bf4d92112..484d589dea7e 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -127,7 +127,7 @@ inline void PrintBinaryIntrinsitc(const Call* op, os << ')'; } -void CodeGenHybrid::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*) if (op->dtype == op->value.dtype()) { PrintExpr(op->value, stream); } else { @@ -141,17 +141,17 @@ void CodeGenHybrid::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*) void CodeGenHybrid::VisitExpr_(const VarNode* op, std::ostream& os) { // NOLINT(*) os << GetVarID(op); } -void CodeGenHybrid::VisitExpr_(const Add* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const AddNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "+", os, this); } -void CodeGenHybrid::VisitExpr_(const Sub* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const SubNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "-", os, this); } -void CodeGenHybrid::VisitExpr_(const Mul* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const MulNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "*", os, this); } -void CodeGenHybrid::VisitExpr_(const Div* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*) if (op->dtype.is_int()) PrintBinaryExpr(op, "//", os, this); else @@ -165,7 +165,7 @@ void CodeGenHybrid::VisitExpr_(const FloorDiv* op, std::ostream& os) { // NOLIN PrintBinaryExpr(op, "/", os, this); } -void CodeGenHybrid::VisitExpr_(const Mod* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "%", os, this); } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index a43b98aa6174..63bf1f74eb9d 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -94,11 +94,11 @@ class CodeGenHybrid : void VisitExpr_(const Load* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Let* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Call* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Add* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Sub* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Mul* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Div* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Mod* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloorDiv* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const FloorMod* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*) @@ -111,7 +111,7 @@ class CodeGenHybrid : void VisitExpr_(const GE* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const And* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Or* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Cast* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Not* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Select* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const Ramp* op, std::ostream& os) override; // NOLINT(*) diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index 49ab2fd2a2ba..bf18085c2a61 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -82,11 +82,11 @@ class AttrFunctor { virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT; // deep comparison of symbolic integer expressions. virtual R VisitAttr_(const VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::DivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::FloorDiv* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::FloorMod* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT; @@ -100,7 +100,7 @@ class AttrFunctor { virtual R VisitAttr_(const ir::And* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Or* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Not* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::Cast* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Call* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Select* op, Args... args) ATTR_FUNCTOR_DEFAULT; @@ -117,11 +117,11 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(FloatImm); ATTR_FUNCTOR_DISPATCH(StringImm); ATTR_FUNCTOR_DISPATCH(VarNode); - ATTR_FUNCTOR_DISPATCH(Add); - ATTR_FUNCTOR_DISPATCH(Sub); - ATTR_FUNCTOR_DISPATCH(Mul); - ATTR_FUNCTOR_DISPATCH(Div); - ATTR_FUNCTOR_DISPATCH(Mod); + ATTR_FUNCTOR_DISPATCH(AddNode); + ATTR_FUNCTOR_DISPATCH(SubNode); + ATTR_FUNCTOR_DISPATCH(MulNode); + ATTR_FUNCTOR_DISPATCH(DivNode); + ATTR_FUNCTOR_DISPATCH(ModNode); ATTR_FUNCTOR_DISPATCH(FloorDiv); ATTR_FUNCTOR_DISPATCH(FloorMod); ATTR_FUNCTOR_DISPATCH(Min); @@ -135,7 +135,7 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(And); ATTR_FUNCTOR_DISPATCH(Or); ATTR_FUNCTOR_DISPATCH(Not); - ATTR_FUNCTOR_DISPATCH(Cast); + ATTR_FUNCTOR_DISPATCH(CastNode); ATTR_FUNCTOR_DISPATCH(Call); ATTR_FUNCTOR_DISPATCH(Select); return vtable; @@ -160,11 +160,11 @@ class AttrsEqualHandler : bool VisitAttr_(const ir::UIntImm* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::FloatImm* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::StringImm* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::Add* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::Sub* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::Mul* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::Div* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::Mod* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::AddNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::SubNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::MulNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::DivNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::ModNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::FloorDiv* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::FloorMod* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::Min* lhs, const ObjectRef& other) final; @@ -178,7 +178,7 @@ class AttrsEqualHandler : bool VisitAttr_(const ir::And* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::Or* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::Not* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::Cast* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::CastNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::Call* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::Select* lhs, const ObjectRef& other) final; }; @@ -203,11 +203,11 @@ class AttrsHashHandler : size_t VisitAttr_(const ir::StringImm* lhs) final; size_t VisitAttr_(const ArrayNode* lhs) final; size_t VisitAttr_(const StrMapNode* lhs) final; - size_t VisitAttr_(const ir::Add* op) final; - size_t VisitAttr_(const ir::Sub* op) final; - size_t VisitAttr_(const ir::Mul* op) final; - size_t VisitAttr_(const ir::Div* op) final; - size_t VisitAttr_(const ir::Mod* op) final; + size_t VisitAttr_(const ir::AddNode* op) final; + size_t VisitAttr_(const ir::SubNode* op) final; + size_t VisitAttr_(const ir::MulNode* op) final; + size_t VisitAttr_(const ir::DivNode* op) final; + size_t VisitAttr_(const ir::ModNode* op) final; size_t VisitAttr_(const ir::FloorDiv* op) final; size_t VisitAttr_(const ir::FloorMod* op) final; size_t VisitAttr_(const ir::Min* op) final; @@ -221,7 +221,7 @@ class AttrsHashHandler : size_t VisitAttr_(const ir::And* op) final; size_t VisitAttr_(const ir::Or* op) final; size_t VisitAttr_(const ir::Not* op) final; - size_t VisitAttr_(const ir::Cast* op) final; + size_t VisitAttr_(const ir::CastNode* op) final; size_t VisitAttr_(const ir::Call* op) final; size_t VisitAttr_(const ir::Select* op) final; /*! diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index d69e3e2ad703..d17b2d905430 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -151,11 +151,11 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other } \ } \ -TVM_DEFINE_ATTRS_BINOP_EQUAL(Add); -TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub); -TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul); -TVM_DEFINE_ATTRS_BINOP_EQUAL(Div); -TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod); +TVM_DEFINE_ATTRS_BINOP_EQUAL(AddNode); +TVM_DEFINE_ATTRS_BINOP_EQUAL(SubNode); +TVM_DEFINE_ATTRS_BINOP_EQUAL(MulNode); +TVM_DEFINE_ATTRS_BINOP_EQUAL(DivNode); +TVM_DEFINE_ATTRS_BINOP_EQUAL(ModNode); TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDiv); TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorMod); TVM_DEFINE_ATTRS_BINOP_EQUAL(Max); @@ -177,8 +177,8 @@ bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const ObjectRef& other) { } } -bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const ObjectRef& other) { - if (const auto* rhs = other.as()) { +bool AttrsEqualHandler::VisitAttr_(const CastNode* lhs, const ObjectRef& other) { + if (const auto* rhs = other.as()) { if (lhs->dtype != rhs->dtype) return false; return Equal(lhs->value, rhs->value); } else { @@ -265,11 +265,11 @@ size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) { return Combine(key, Combine(Hash(op->a), Hash(op->b))); \ } \ -TVM_DEFINE_ATTRS_BINOP_HASH(Add); -TVM_DEFINE_ATTRS_BINOP_HASH(Sub); -TVM_DEFINE_ATTRS_BINOP_HASH(Mul); -TVM_DEFINE_ATTRS_BINOP_HASH(Div); -TVM_DEFINE_ATTRS_BINOP_HASH(Mod); +TVM_DEFINE_ATTRS_BINOP_HASH(AddNode); +TVM_DEFINE_ATTRS_BINOP_HASH(SubNode); +TVM_DEFINE_ATTRS_BINOP_HASH(MulNode); +TVM_DEFINE_ATTRS_BINOP_HASH(DivNode); +TVM_DEFINE_ATTRS_BINOP_HASH(ModNode); TVM_DEFINE_ATTRS_BINOP_HASH(FloorDiv); TVM_DEFINE_ATTRS_BINOP_HASH(FloorMod); TVM_DEFINE_ATTRS_BINOP_HASH(Max); @@ -288,8 +288,8 @@ size_t AttrsHashHandler::VisitAttr_(const Not* op) { return Combine(key, Hash(op->a)); } -size_t AttrsHashHandler::VisitAttr_(const Cast* op) { - static size_t key = std::hash()(Cast::_type_key); +size_t AttrsHashHandler::VisitAttr_(const CastNode* op) { + static size_t key = std::hash()(CastNode::_type_key); AttrsHash hasher; size_t res = key; res = Combine(res, hasher(op->dtype)); diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 22efa1dcc1bf..33c6a707fc79 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -65,7 +65,7 @@ inline std::vector ExprSplitAddition(const Expr &expr) { while (!split_buffer.empty()) { const Expr* top_ele = split_buffer.top(); split_buffer.pop(); - auto expr_add_match = top_ele->as(); + auto expr_add_match = top_ele->as(); if (expr_add_match) { split_buffer.push(&expr_add_match->b); split_buffer.push(&expr_add_match->a); @@ -88,13 +88,13 @@ inline std::pair MergeMulModInner(const Expr &mult_expr, const Expr &mod_l_expr, const Expr &mod_r_expr) { using namespace ir; - const Mul* mult_ptr = mult_expr.as(); + const MulNode* mult_ptr = mult_expr.as(); if (!mult_ptr) return std::make_pair(false, Expr()); Expr mult_outer = mult_ptr->b; const Expr* inner = &(mult_ptr->a); // 1. Calculate the outer multiplier while (true) { - mult_ptr = inner->as(); + mult_ptr = inner->as(); if (mult_ptr) { inner = &(mult_ptr->a); mult_outer = mult_ptr->b * mult_outer; @@ -113,8 +113,8 @@ inline std::pair MergeMulModInner(const Expr &mult_expr, Expr no_opt_sum; // Sum of the exprs that cannot be optimized while (true) { auto inner_div_ptr = search_ptr->as(); - auto inner_mult_ptr = search_ptr->as(); - auto inner_add_ptr = search_ptr->as(); + auto inner_mult_ptr = search_ptr->as(); + auto inner_add_ptr = search_ptr->as(); if (!inner_div_ptr && !inner_mult_ptr && !inner_add_ptr) { return std::make_pair(false, Expr()); } else if (inner_div_ptr) { @@ -160,7 +160,7 @@ inline void MergeMulModInsertElements(const std::vector& eles, *has_mod = false; for (const Expr* ele : eles) { auto mod_ptr = ele->as(); - auto mult_ptr = ele->as(); + auto mult_ptr = ele->as(); if (mod_ptr) { *has_mod = true; mod_exprs->emplace_back(std::make_pair(std::move(mod_ptr->a), std::move(mod_ptr->b))); @@ -299,7 +299,7 @@ Expr Buffer::vload(Array begin, DataType dtype) const { << "Cannot load " << dtype << " from buffer of " << n->dtype; if (dtype == DataType::Bool()) { - return ir::Cast::make( + return ir::CastNode::make( DataType::Bool(), ir::Load::make( DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), @@ -321,7 +321,7 @@ Stmt Buffer::vstore(Array begin, Expr value) const { << " from buffer of " << n->dtype; if (value.dtype() == DataType::Bool()) { return ir::Store::make(n->data, - ir::Cast::make(DataType::Int(8), value), + ir::CastNode::make(DataType::Int(8), value), BufferOffset(n, begin, DataType::Int(8)), const_true()); } else { @@ -391,7 +391,7 @@ Expr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, E int highest_dim = 0; extent = self->strides[highest_dim] * self->shape[highest_dim] - offset; } else { - extent = arith::ComputeReduce(self->shape, Expr()) - offset; + extent = arith::ComputeReduce(self->shape, Expr()) - offset; } Expr elem_offset = self->elem_offset + offset; if (content_lanes > 1) { diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 1166e7eef976..7c8a791e7c24 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -32,7 +32,7 @@ namespace tvm { // simple cast that only checks if type matches and cast inline Expr SimpleCast(const DataType& t, Expr value) { if (value.dtype() == t) return value; - return ir::Cast::make(t, value); + return ir::CastNode::make(t, value); } // The public function with a quick checking path. @@ -176,7 +176,7 @@ Expr cast(const DataType& t, Expr value) { } else if (const FloatImm* op = value.as()) { return make_const(t, op->value); } - return ir::Cast::make(t, value); + return ir::CastNode::make(t, value); } else { if (value.dtype().lanes() == 1) { // manually unroll cast @@ -189,13 +189,13 @@ Expr cast(const DataType& t, Expr value) { } else if (const FloatImm* op = value.as()) { value = make_const(vtype, op->value); } else { - value = ir::Cast::make(vtype, value); + value = ir::CastNode::make(vtype, value); } } return ir::Broadcast::make(value, t.lanes()); } else { CHECK(value.dtype().lanes() == t.lanes()); - return ir::Cast::make(t, value); + return ir::CastNode::make(t, value); } } } @@ -207,9 +207,9 @@ Expr reinterpret(const DataType& t, Expr value) { Expr operator+(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Add::make(a, b); + return ir::AddNode::make(a, b); } // negation @@ -225,23 +225,23 @@ Expr operator-(Expr a) { Expr operator-(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Sub::make(a, b); + return ir::SubNode::make(a, b); } Expr operator*(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Mul::make(a, b); + return ir::MulNode::make(a, b); } Expr div(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Div::make(a, b); + return ir::DivNode::make(a, b); } Expr truncdiv(Expr a, Expr b) { @@ -252,9 +252,9 @@ Expr truncdiv(Expr a, Expr b) { Expr truncmod(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Mod::make(a, b); + return ir::ModNode::make(a, b); } Expr operator/(Expr a, Expr b) { @@ -528,7 +528,7 @@ Expr isnan(Expr x) { Expr sum(Expr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::Add::make(x, y); + Expr result = ir::AddNode::make(x, y); Expr identity_element = make_zero(source.dtype()); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); @@ -575,7 +575,7 @@ Expr min(Expr source, Array rdom) { Expr prod(Expr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::Mul::make(x, y); + Expr result = ir::MulNode::make(x, y); Expr identity_element = make_const(source.dtype(), 1); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); diff --git a/src/lang/ir.cc b/src/lang/ir.cc index de047f330630..cda2d3b49997 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -56,10 +56,10 @@ Expr StringImm::make(std::string value) { return Expr(node); } -Expr Cast::make(DataType t, Expr value) { +Expr CastNode::make(DataType t, Expr value) { CHECK(value.defined()); CHECK_EQ(t.lanes(), value.dtype().lanes()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = t; node->value = std::move(value); return Expr(node); @@ -593,8 +593,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << op->dtype << '('; p->Print(op->value); p->stream << ')'; @@ -605,40 +605,40 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) // stream << op->name << "." << op->type; p->stream << op->name_hint; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " + "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " - "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << "*"; p->Print(op->b); p->stream << ')'; }) -.set_dispatch
([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << "/"; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " % "; @@ -1155,13 +1155,13 @@ TVM_REGISTER_NODE_TYPE(FloatImm); TVM_REGISTER_NODE_TYPE(IntImm); TVM_REGISTER_NODE_TYPE(UIntImm); TVM_REGISTER_NODE_TYPE(StringImm); -TVM_REGISTER_NODE_TYPE(Cast); +TVM_REGISTER_NODE_TYPE(CastNode); TVM_REGISTER_NODE_TYPE(Variable); -TVM_REGISTER_NODE_TYPE(Add); -TVM_REGISTER_NODE_TYPE(Sub); -TVM_REGISTER_NODE_TYPE(Mul); -TVM_REGISTER_NODE_TYPE(Div); -TVM_REGISTER_NODE_TYPE(Mod); +TVM_REGISTER_NODE_TYPE(AddNode); +TVM_REGISTER_NODE_TYPE(SubNode); +TVM_REGISTER_NODE_TYPE(MulNode); +TVM_REGISTER_NODE_TYPE(DivNode); +TVM_REGISTER_NODE_TYPE(ModNode); TVM_REGISTER_NODE_TYPE(FloorDiv); TVM_REGISTER_NODE_TYPE(FloorMod); TVM_REGISTER_NODE_TYPE(Min); diff --git a/src/pass/bound_checker.cc b/src/pass/bound_checker.cc index d3898a2ecac6..3166df4d9fac 100644 --- a/src/pass/bound_checker.cc +++ b/src/pass/bound_checker.cc @@ -122,12 +122,12 @@ class BoundChecker : public StmtExprMutator { } // Scalarize the shape. - Expr shape = Mul::make(make_const(DataType::UInt(64), type.lanes()), - Cast::make(DataType::UInt(64), new_shape[0])); + Expr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()), + CastNode::make(DataType::UInt(64), new_shape[0])); for (size_t i = 1; i < new_shape.size(); ++i) { // Cast to unsigned to avoid integer overlow at frist. - shape = Mul::make(shape, Mul::make(make_const(DataType::UInt(64), type.lanes()), - Cast::make(DataType::UInt(64), new_shape[i]))); + shape = MulNode::make(shape, MulNode::make(make_const(DataType::UInt(64), type.lanes()), + CastNode::make(DataType::UInt(64), new_shape[i]))); } mem_to_shape_[buffer_var.get()] = shape; } @@ -166,9 +166,9 @@ class BoundChecker : public StmtExprMutator { if (const Ramp *ramp_index = index.as()) { // In case index is base + stride * i. // Non inclusive range. - index = Add::make( + index = AddNode::make( ramp_index->base, - Mul::make(ramp_index->stride, make_const(ramp_index->stride.dtype(), + MulNode::make(ramp_index->stride, make_const(ramp_index->stride.dtype(), ramp_index->lanes - 1))); } @@ -177,8 +177,8 @@ class BoundChecker : public StmtExprMutator { upper_bound = ir::Simplify(upper_bound); // Cast to the same type - signed, to be able to check lower bound. - index = Cast::make(DataType::Int(64), index); - upper_bound = Cast::make(DataType::Int(64), upper_bound); + index = CastNode::make(DataType::Int(64), index); + upper_bound = CastNode::make(DataType::Int(64), upper_bound); // Looks like a lower bound should always be zero after normalization. Expr lower_bound = make_zero(DataType::Int(64)); diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index d1ba19b9fb05..6ccf393e6a99 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -74,7 +74,7 @@ class CopyIntrinInjector : public StmtMutator { if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) || select(sel_cond, sel_true_value, sel_false_value).Match(store->value); - const Cast* cast = store->value.as(); + const CastNode* cast = store->value.as(); const Load* load = store->value.as(); if (0 == loops.size()) { CHECK(!has_cond); diff --git a/src/pass/inject_double_buffer.cc b/src/pass/inject_double_buffer.cc index 0158a949da53..b9028e02bde7 100644 --- a/src/pass/inject_double_buffer.cc +++ b/src/pass/inject_double_buffer.cc @@ -98,7 +98,7 @@ class DoubleBufferInjector : public StmtExprMutator { Stmt VisitStmt_(const Allocate* op) final { auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { - it->second.stride = arith::ComputeReduce( + it->second.stride = arith::ComputeReduce( op->extents, Expr()) * op->dtype.lanes(); Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index 0887a83c1a48..0a3a8d621f97 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -395,7 +395,7 @@ class VTInjector : public StmtExprMutator { // always rewrite if not allow sharing. if (touched_var_.count(op->buffer_var.get()) || !allow_share_) { // place v on highest dimension. - Expr stride = arith::ComputeReduce( + Expr stride = arith::ComputeReduce( op->extents, Expr()) * op->dtype.lanes(); Array other; other.push_back(make_const(op->extents[0].dtype(), num_threads_)); diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index bbee9eeb7c8a..62cd9ab629cf 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -264,8 +264,8 @@ class IRDeepCompare : CompareString(op->value, other.as()->value); } - void VisitExpr_(const Cast *op, const Expr& other) final { - CompareExpr(op->value, other.as()->value); + void VisitExpr_(const CastNode *op, const Expr& other) final { + CompareExpr(op->value, other.as()->value); } void VisitExpr_(const Not *op, const Expr& other) final { @@ -298,11 +298,11 @@ class IRDeepCompare : if (CompareArray(op->indices, rhs->indices) != 0) return; } - DEFINE_BIOP_EXPR_CMP_(Add) - DEFINE_BIOP_EXPR_CMP_(Sub) - DEFINE_BIOP_EXPR_CMP_(Mul) - DEFINE_BIOP_EXPR_CMP_(Div) - DEFINE_BIOP_EXPR_CMP_(Mod) + DEFINE_BIOP_EXPR_CMP_(AddNode) + DEFINE_BIOP_EXPR_CMP_(SubNode) + DEFINE_BIOP_EXPR_CMP_(MulNode) + DEFINE_BIOP_EXPR_CMP_(DivNode) + DEFINE_BIOP_EXPR_CMP_(ModNode) DEFINE_BIOP_EXPR_CMP_(FloorDiv) DEFINE_BIOP_EXPR_CMP_(FloorMod) DEFINE_BIOP_EXPR_CMP_(Min) diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index dddf90eb47aa..27305917bd81 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -241,11 +241,11 @@ void ExprVisitor::VisitExpr_(const Call* op) { this->VisitExpr(op->b); \ } -DEFINE_BINOP_VISIT_(Add); -DEFINE_BINOP_VISIT_(Sub); -DEFINE_BINOP_VISIT_(Mul); -DEFINE_BINOP_VISIT_(Div); -DEFINE_BINOP_VISIT_(Mod); +DEFINE_BINOP_VISIT_(AddNode); +DEFINE_BINOP_VISIT_(SubNode); +DEFINE_BINOP_VISIT_(MulNode); +DEFINE_BINOP_VISIT_(DivNode); +DEFINE_BINOP_VISIT_(ModNode); DEFINE_BINOP_VISIT_(FloorDiv); DEFINE_BINOP_VISIT_(FloorMod); DEFINE_BINOP_VISIT_(Min); @@ -273,7 +273,7 @@ void ExprVisitor::VisitExpr_(const Reduce* op) { this->VisitExpr(op->condition); } -void ExprVisitor::VisitExpr_(const Cast* op) { +void ExprVisitor::VisitExpr_(const CastNode* op) { this->VisitExpr(op->value); } @@ -656,11 +656,11 @@ DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm) } \ } -DEFINE_BIOP_EXPR_MUTATE_(Add); -DEFINE_BIOP_EXPR_MUTATE_(Sub); -DEFINE_BIOP_EXPR_MUTATE_(Mul); -DEFINE_BIOP_EXPR_MUTATE_(Div); -DEFINE_BIOP_EXPR_MUTATE_(Mod); +DEFINE_BIOP_EXPR_MUTATE_(AddNode); +DEFINE_BIOP_EXPR_MUTATE_(SubNode); +DEFINE_BIOP_EXPR_MUTATE_(MulNode); +DEFINE_BIOP_EXPR_MUTATE_(DivNode); +DEFINE_BIOP_EXPR_MUTATE_(ModNode); DEFINE_BIOP_EXPR_MUTATE_(FloorDiv); DEFINE_BIOP_EXPR_MUTATE_(FloorMod); DEFINE_BIOP_EXPR_MUTATE_(Min); @@ -705,12 +705,12 @@ Expr ExprMutator::VisitExpr_(const Reduce* op) { } } -Expr ExprMutator::VisitExpr_(const Cast* op) { +Expr ExprMutator::VisitExpr_(const CastNode* op) { Expr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); } else { - return Cast::make(op->dtype, value); + return CastNode::make(op->dtype, value); } } diff --git a/src/pass/lower_custom_datatypes.cc b/src/pass/lower_custom_datatypes.cc index 2440b1f64a65..3c0439a49c15 100644 --- a/src/pass/lower_custom_datatypes.cc +++ b/src/pass/lower_custom_datatypes.cc @@ -41,14 +41,14 @@ class CustomDatatypesLowerer : public StmtExprMutator { public: explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {} - inline Expr VisitExpr_(const Cast* op) final { + inline Expr VisitExpr_(const CastNode* op) final { auto type_code = op->dtype.code(); auto src_type_code = op->value.dtype().code(); // If either datatype is a registered custom datatype, we must lower. bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) || datatype::Registry::Global()->GetTypeRegistered(src_type_code); Expr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); + op = expr.as(); if (toBeLowered) { auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code); CHECK(lower) << "Cast lowering function for target " << target_ << " destination type " @@ -96,12 +96,12 @@ class CustomDatatypesLowerer : public StmtExprMutator { return expr; } -#define DEFINE_MUTATE__(OP) \ - inline Expr VisitExpr_(const OP* op) final { \ +#define DEFINE_MUTATE__(OP, NodeName) \ + inline Expr VisitExpr_(const NodeName* op) final { \ auto type_code = op->dtype.code(); \ bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ Expr expr = StmtExprMutator::VisitExpr_(op); \ - op = expr.as(); \ + op = expr.as(); \ if (toBeLowered) { \ auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ @@ -111,19 +111,19 @@ class CustomDatatypesLowerer : public StmtExprMutator { return expr; \ } - DEFINE_MUTATE__(Add) - DEFINE_MUTATE__(Sub) - DEFINE_MUTATE__(Mul) - DEFINE_MUTATE__(Div) - DEFINE_MUTATE__(Mod) - DEFINE_MUTATE__(Min) - DEFINE_MUTATE__(Max) - DEFINE_MUTATE__(EQ) - DEFINE_MUTATE__(NE) - DEFINE_MUTATE__(LT) - DEFINE_MUTATE__(LE) - DEFINE_MUTATE__(GT) - DEFINE_MUTATE__(GE) + DEFINE_MUTATE__(Add, AddNode); + DEFINE_MUTATE__(Sub, SubNode); + DEFINE_MUTATE__(Mul, MulNode); + DEFINE_MUTATE__(Div, DivNode); + DEFINE_MUTATE__(Mod, ModNode); + DEFINE_MUTATE__(Min, Min); + DEFINE_MUTATE__(Max, Max); + DEFINE_MUTATE__(EQ, EQ); + DEFINE_MUTATE__(NE, NE); + DEFINE_MUTATE__(LT, LT); + DEFINE_MUTATE__(LE, LE); + DEFINE_MUTATE__(GT, GT); + DEFINE_MUTATE__(GE, GE); // Later changes may need to add more mutate functions as we support workloads with more ops. private: diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index 0f4971022740..68b7253c401f 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -62,10 +62,10 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return IRMutatorWithAnalyzer::VisitExpr_(op); } - Expr VisitExpr_(const Add* op) final { - if (const Mul* mb = op->b.as()) { + Expr VisitExpr_(const AddNode* op) final { + if (const MulNode* mb = op->b.as()) { return MakeFMA(mb->a, mb->b, op->a, op); - } else if (const Mul* ma = op->a.as()) { + } else if (const MulNode* ma = op->a.as()) { return MakeFMA(ma->a, ma->b, op->b, op); } return IRMutatorWithAnalyzer::VisitExpr_(op); @@ -210,7 +210,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // instruction with the latter. For example, vmla vs. vmlal // on ARM. if (const Broadcast* bcast = e.as()) { - if (const Cast* cast = bcast->value.as()) { + if (const CastNode* cast = bcast->value.as()) { auto should_swap = [&]() { // Maintain behaviour (int8 -> int16, fp16 -> fp32). if (cast->dtype.bits() == cast->value.dtype().bits() * 2) { @@ -229,7 +229,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { if (should_swap()) { Expr new_bcast = Broadcast::make(cast->value, bcast->lanes); - return Cast::make(bcast->dtype, new_bcast); + return CastNode::make(bcast->dtype, new_bcast); } } } @@ -237,7 +237,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } Expr MakeFMA(const Expr& a, const Expr& b, const Expr& c, - const Add* op) { + const AddNode* op) { // emit fma instruction: a * b + c Expr lhs = SwapBroadcastCast(a); Expr rhs = SwapBroadcastCast(b); @@ -248,8 +248,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { if (r.defined()) return this->VisitExpr(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { - Expr mul = this->VisitExpr(Mul::make(lhs, rhs)); - return Add::make(mul, this->VisitExpr(c)); + Expr mul = this->VisitExpr(MulNode::make(lhs, rhs)); + return AddNode::make(mul, this->VisitExpr(c)); } } return IRMutatorWithAnalyzer::VisitExpr_(op); diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index c0b98793c7f9..f573a49d2660 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -247,7 +247,7 @@ class BuiltinLower : public StmtExprMutator { DataType t = arg.dtype(); DataType api_type = APIType(t); if (t != api_type) { - arg = Cast::make(api_type, arg); + arg = CastNode::make(api_type, arg); } prep_seq_.emplace_back(TVMStructSet( stack_value_, static_cast(arg_stack_begin + i - 1), @@ -296,7 +296,7 @@ class BuiltinLower : public StmtExprMutator { DataType t = arg.dtype(); DataType api_type = APIType(t); if (t != api_type) { - arg = Cast::make(api_type, arg); + arg = CastNode::make(api_type, arg); } prep_seq_.emplace_back(TVMStructSet( stack_value_, static_cast(arg_stack_begin + i - 1), diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index f065502db6b4..048289e24710 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -76,7 +76,7 @@ LoweredFunc MakeAPI(Stmt body, Call::PureIntrinsic); // cast to the target version. if (api_type != t) { - res = Cast::make(t, res); + res = CastNode::make(t, res); } return res; }; diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc index 0a276719514d..8886222d2dca 100644 --- a/src/pass/rewrite_unsafe_select.cc +++ b/src/pass/rewrite_unsafe_select.cc @@ -57,11 +57,11 @@ class UnsafeExprDetector : public ExprFunctor { // Load is considered unsafe. return true; } - bool VisitExpr_(const Add* op) final { return BinaryOp(op); } - bool VisitExpr_(const Sub* op) final { return BinaryOp(op); } - bool VisitExpr_(const Mul* op) final { return BinaryOp(op); } - bool VisitExpr_(const Div* op) final { return BinaryOp(op); } - bool VisitExpr_(const Mod* op) final { return BinaryOp(op); } + bool VisitExpr_(const AddNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const SubNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const MulNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const DivNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const ModNode* op) final { return BinaryOp(op); } bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); } bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); } bool VisitExpr_(const Min* op) final { return BinaryOp(op); } @@ -80,7 +80,7 @@ class UnsafeExprDetector : public ExprFunctor { bool VisitExpr_(const Let* op) final { return VisitExpr(op->body) || VisitExpr(op->value); } - bool VisitExpr_(const Cast* op) final { + bool VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } bool VisitExpr_(const Broadcast* op) final { diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index 6bb3fc5a6025..a073f60a7b49 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -495,10 +495,10 @@ class StorageFlattener : public StmtExprMutator { Expr MakeBound(const DataType &type, const Array &shape) { // We have already checked the shape size to be greater then 0. - Expr bound = Mul::make(make_const(shape[0].dtype(), type.lanes()), shape[0]); + Expr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]); for (size_t i = 1; i < shape.size(); ++i) { - bound = Mul::make( - bound, Mul::make(make_const(bound.dtype(), type.lanes()), shape[i])); + bound = MulNode::make( + bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i])); } return bound; } diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index c820c477e128..b4cb266bd9a9 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -577,7 +577,7 @@ class StoragePlanRewriter : public StmtExprMutator { } if (e->allocs.size() == 1) { // simply use the original allocation. - Expr sz = arith::ComputeReduce(e->allocs[0]->extents, + Expr sz = arith::ComputeReduce(e->allocs[0]->extents, make_const(DataType::Int(32), 1)); e->new_alloc = Allocate::make( e->alloc_var, alloc_type, {sz}, @@ -592,7 +592,7 @@ class StoragePlanRewriter : public StmtExprMutator { // Build a merged allocation Expr combo_size; for (const Allocate* op : e->allocs) { - Expr sz = arith::ComputeReduce(op->extents, make_const(DataType::Int(32), 1)); + Expr sz = arith::ComputeReduce(op->extents, make_const(DataType::Int(32), 1)); auto nbits = op->dtype.bits() * op->dtype.lanes(); if (const auto* imm = sz.as()) { if (imm->value > std::numeric_limits::max() / nbits) { diff --git a/src/pass/tensor_core.cc b/src/pass/tensor_core.cc index a3890cde773e..cc9bd2870531 100644 --- a/src/pass/tensor_core.cc +++ b/src/pass/tensor_core.cc @@ -60,7 +60,7 @@ std::string simplify_name(std::string input) { } Expr unpack_type_cast(const Expr &input, const DataType &target_type) { - auto cast = input.as(); + auto cast = input.as(); if (cast == nullptr) { return input; } else if (cast->dtype == target_type) { @@ -174,7 +174,7 @@ class MMAMatcher: public StmtVisitor { // Do the pattern matching bool mma_sync_match_(const Provide* op, BufferInfo store_buffer) { - auto* add = op->value.as(); + auto* add = op->value.as(); if (add == nullptr) { return false; } @@ -188,7 +188,7 @@ class MMAMatcher: public StmtVisitor { return false; } - auto mul = unpack_type_cast(add->b, buffer_c.dtype).as(); + auto mul = unpack_type_cast(add->b, buffer_c.dtype).as(); if (mul == nullptr) { return false; } @@ -239,13 +239,13 @@ class BodyVisitor : public StmtExprVisitor { BodyVisitor() {} void VisitExpr_(const Reduce* op) final { - auto* comm_add = op->combiner->result[0].as(); + auto* comm_add = op->combiner->result[0].as(); if (comm_add == nullptr || op->combiner->result.size() > 1) { return; } for (Expr source : op->source) { - auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as(); - auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as(); + auto mul_0 = unpack_type_cast(source, DataType::Float(32)).as(); + auto mul_1 = unpack_type_cast(source, DataType::Int(32)).as(); if (mul_0 == nullptr && mul_1 == nullptr) { continue; } @@ -464,7 +464,7 @@ class BufferAnalyser : public StmtExprVisitor { for (size_t i = 1; i < bi.shape.size(); ++i) { Expr stride = IntImm::make(DataType::Int(32), 1); for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = Mul::make(stride, bi.shape[j]); + stride = MulNode::make(stride, bi.shape[j]); } strides.push_back(stride); } @@ -577,7 +577,7 @@ class BufferAnalyser : public StmtExprVisitor { for (size_t i = 1; i < bi.shape.size(); ++i) { Expr stride = IntImm::make(DataType::Int(32), 1); for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = Mul::make(stride, bi.shape[j]); + stride = MulNode::make(stride, bi.shape[j]); } strides.push_back(stride); } @@ -769,8 +769,8 @@ class ThreadIdxMutator : public StmtExprMutator { return zero; } if (op->name_hint == "threadIdx.y") { - Expr div = Div::make(expr, warp_y_); - Expr mul = Mul::make(div, warp_y_); + Expr div = DivNode::make(expr, warp_y_); + Expr mul = MulNode::make(div, warp_y_); return mul; } } @@ -1091,7 +1091,7 @@ class TensorCoreIRMutator : public StmtExprMutator { for (size_t i = 1; i < shape.size(); ++i) { Expr stride = IntImm::make(DataType::Int(32), 1); for (size_t j = shape.size() - 1; j >= i; --j) { - stride = Mul::make(stride, shape[j]); + stride = MulNode::make(stride, shape[j]); } strides.push_back(stride); } @@ -1100,9 +1100,9 @@ class TensorCoreIRMutator : public StmtExprMutator { Expr elem_offset = IntImm::make(DataType::Int(32), 0); CHECK_EQ(call->args.size(), min_bound.size()); for (size_t i = 0; i < min_bound.size(); i++) { - elem_offset = Add::make( - elem_offset, Mul::make( - strides[i], Sub::make(call->args[i], min_bound[i]))); + elem_offset = AddNode::make( + elem_offset, MulNode::make( + strides[i], SubNode::make(call->args[i], min_bound[i]))); } auto it2 = matrix_abc_.find(simplify_name(call->name)); diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index c22243cc8f93..78066ef1b6ef 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -111,13 +111,13 @@ class Vectorizer : public StmtExprMutator { } } - Expr VisitExpr_(const Add* op) final { + Expr VisitExpr_(const AddNode* op) final { return AddSubVec(op); } - Expr VisitExpr_(const Sub* op) final { + Expr VisitExpr_(const SubNode* op) final { return AddSubVec(op); } - Expr VisitExpr_(const Mul* op) final { + Expr VisitExpr_(const MulNode* op) final { Expr a = this->VisitExpr(op->a); Expr b = this->VisitExpr(op->b); if (a.same_as(op->a) && @@ -137,14 +137,14 @@ class Vectorizer : public StmtExprMutator { b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); } } - return Mul::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + return MulNode::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } return BinaryVec(op); } - Expr VisitExpr_(const Div* op) final { + Expr VisitExpr_(const DivNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const Mod* op) final { + Expr VisitExpr_(const ModNode* op) final { return BinaryVec(op); } Expr VisitExpr_(const FloorDiv* op) final { @@ -219,12 +219,12 @@ class Vectorizer : public StmtExprMutator { return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); } } - Expr VisitExpr_(const Cast *op) final { + Expr VisitExpr_(const CastNode *op) final { Expr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); } else { - return Cast::make(op->dtype.with_lanes(value.dtype().lanes()), value); + return CastNode::make(op->dtype.with_lanes(value.dtype().lanes()), value); } } // Variable diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 1f2a016f150f..73cb5a1c4e11 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -83,8 +83,8 @@ bool UpSamplingRel(const Array& types, << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(2, ir::Cast::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h))); - oshape.Set(3, ir::Cast::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w))); + oshape.Set(2, ir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h))); + oshape.Set(3, ir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w))); // assign output type reporter->Assign(types[1], @@ -162,9 +162,9 @@ bool UpSampling3DRel(const Array& types, << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(2, ir::Cast::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d))); - oshape.Set(3, ir::Cast::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h))); - oshape.Set(4, ir::Cast::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w))); + oshape.Set(2, ir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d))); + oshape.Set(3, ir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h))); + oshape.Set(4, ir::CastNode::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w))); // assign output type reporter->Assign(types[1], diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index eaaec172f971..fd3708f70240 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -34,7 +34,7 @@ TEST(IRF, Basic) { f.set_dispatch([](const ObjectRef& n, int b) { return b; }); - f.set_dispatch([](const ObjectRef& n, int b) { + f.set_dispatch([](const ObjectRef& n, int b) { return b + 2; }); CHECK_EQ(f(x, 2), 2); @@ -69,7 +69,7 @@ TEST(IRF, ExprTransform) { int VisitExpr_(const IntImm* op, int b) final { return op->value; } - int VisitExpr_(const Add* op, int b) final { + int VisitExpr_(const AddNode* op, int b) final { return VisitExpr(op->a, b) + VisitExpr(op->b, b); } }; @@ -152,7 +152,7 @@ TEST(IRF, StmtMutator) { protected: // implementation - Expr VisitExpr_(const Add* op) final { + Expr VisitExpr_(const AddNode* op) final { return op->a; } Stmt VisitStmt_(const SeqStmtNode* op) final { @@ -191,7 +191,7 @@ TEST(IRF, StmtMutator) { // copy because there is additional refs CHECK(!arr[0].as()->body.same_as(bref)); CHECK(arr[0].as()->body.as()->value.same_as(x)); - CHECK(bref.as()->value.as()); + CHECK(bref.as()->value.as()); } { Array arr{fmakealloc()}; diff --git a/tests/cpp/ir_simplify_test.cc b/tests/cpp/ir_simplify_test.cc index 57d7d5041e6d..6b694eff28f3 100644 --- a/tests/cpp/ir_simplify_test.cc +++ b/tests/cpp/ir_simplify_test.cc @@ -46,7 +46,7 @@ TEST(IRSIMPLIFY, Mod) { // Mod::make is used instead of % to avoid constant folding during // calling operator%(x,y). Mod::make doesn't try constant folding, // and therefore, the constant folding will be attempted in CanonicalSimplify - auto mod = tvm::ir::CanonicalSimplify(tvm::ir::Mod::make(x, y)); + auto mod = tvm::ir::CanonicalSimplify(tvm::ir::ModNode::make(x, y)); auto es = tvm::ir::CanonicalSimplify(mod - x); CHECK(is_zero(es)); } diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 9710428d1b13..c4c274ee064b 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -100,13 +100,13 @@ TEST(Pattern, Basic) { // cast pattern { CHECK(!cast(PConst( - DataType::Int(32)), px).Match(ir::Cast::make(DataType::Float(64), x))); - CHECK(cast(pt, px).Match(ir::Cast::make(DataType::Float(64), x))); + DataType::Int(32)), px).Match(ir::CastNode::make(DataType::Float(64), x))); + CHECK(cast(pt, px).Match(ir::CastNode::make(DataType::Float(64), x))); CHECK(pt.Eval() == DataType::Float(64)); auto zz = cast(pt, px).Eval(); CHECK((cast(pt, px) - cast(pt, py)).Match( - ir::Cast::make(DataType::Float(64), x) - ir::Cast::make(DataType::Int(64), x))); - auto expr = ir::Cast::make(DataType::Int(32), ir::Cast::make(DataType::Float(64), x)); + ir::CastNode::make(DataType::Float(64), x) - ir::CastNode::make(DataType::Int(64), x))); + auto expr = ir::CastNode::make(DataType::Int(32), ir::CastNode::make(DataType::Float(64), x)); CHECK(!(cast(pt, cast(pt, px))).Match(expr)); } // ramp pattern From bab21618f2092396e088ec9ecee5ff9c296c369c Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 7 Jan 2020 14:58:56 -0800 Subject: [PATCH 03/11] [REFACTOR][IR] Min/Max/FloorDiv/FloorMod -> MinNode/MaxNode etc. --- include/tvm/ir.h | 8 +++--- include/tvm/ir_functor_ext.h | 32 ++++++++++++------------ src/api/api_ir.cc | 8 +++--- src/arithmetic/canonical_simplify.cc | 16 ++++++------ src/arithmetic/compute_expr.h | 4 +-- src/arithmetic/const_fold.h | 8 +++--- src/arithmetic/const_int_bound.cc | 8 +++--- src/arithmetic/detect_linear_equation.cc | 4 +-- src/arithmetic/int_set.cc | 16 ++++++------ src/arithmetic/modular_set.cc | 6 ++--- src/arithmetic/pattern_match.h | 8 +++--- src/arithmetic/rewrite_simplify.cc | 24 +++++++++--------- src/arithmetic/rewrite_simplify.h | 8 +++--- src/codegen/codegen_c.cc | 4 +-- src/codegen/codegen_c.h | 4 +-- src/codegen/codegen_c_host.cc | 4 +-- src/codegen/codegen_c_host.h | 4 +-- src/codegen/codegen_vhls.cc | 4 +-- src/codegen/codegen_vhls.h | 8 +++--- src/codegen/llvm/codegen_cpu.cc | 4 +-- src/codegen/llvm/codegen_llvm.cc | 4 +-- src/codegen/llvm/codegen_llvm.h | 4 +-- src/codegen/spirv/codegen_spirv.cc | 4 +-- src/codegen/spirv/codegen_spirv.h | 4 +-- src/codegen/stackvm/codegen_stackvm.cc | 4 +-- src/codegen/stackvm/codegen_stackvm.h | 4 +-- src/contrib/hybrid/codegen_hybrid.cc | 8 +++--- src/contrib/hybrid/codegen_hybrid.h | 8 +++--- src/lang/attr_functor.h | 32 ++++++++++++------------ src/lang/attrs.cc | 16 ++++++------ src/lang/buffer.cc | 4 +-- src/lang/expr_operator.cc | 20 +++++++-------- src/lang/ir.cc | 24 +++++++++--------- src/pass/inject_copy_intrin.cc | 4 +-- src/pass/ir_deep_compare.cc | 8 +++--- src/pass/ir_functor.cc | 16 ++++++------ src/pass/loop_partition.cc | 4 +-- src/pass/lower_custom_datatypes.cc | 4 +-- src/pass/lower_intrin.cc | 10 ++++---- src/pass/rewrite_unsafe_select.cc | 8 +++--- src/pass/vectorize_loop.cc | 8 +++--- tests/cpp/expr_test.cc | 2 +- topi/include/topi/nn/pooling.h | 26 +++++++++---------- 43 files changed, 204 insertions(+), 204 deletions(-) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 580c9b0b2455..7b1127a047d9 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -179,25 +179,25 @@ class ModNode : public BinaryOpNode { }; /*! \brief Floor division, floor(a/b) */ -class FloorDiv : public BinaryOpNode { +class FloorDivNode : public BinaryOpNode { public: static constexpr const char* _type_key = "FloorDiv"; }; /*! \brief The remainder of the floordiv */ -class FloorMod : public BinaryOpNode { +class FloorModNode : public BinaryOpNode { public: static constexpr const char* _type_key = "FloorMod"; }; /*! \brief min(a, b) */ -class Min : public BinaryOpNode { +class MinNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Min"; }; /*! \brief max(a, b) */ -class Max : public BinaryOpNode { +class MaxNode : public BinaryOpNode { public: static constexpr const char* _type_key = "Max"; }; diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 19fc345290ad..eb0601be4075 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -141,10 +141,10 @@ class ExprFunctor { virtual R VisitExpr_(const MulNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const FloorDiv* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const FloorMod* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FloorDivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -183,10 +183,10 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(MulNode); IR_EXPR_FUNCTOR_DISPATCH(DivNode); IR_EXPR_FUNCTOR_DISPATCH(ModNode); - IR_EXPR_FUNCTOR_DISPATCH(FloorDiv); - IR_EXPR_FUNCTOR_DISPATCH(FloorMod); - IR_EXPR_FUNCTOR_DISPATCH(Min); - IR_EXPR_FUNCTOR_DISPATCH(Max); + IR_EXPR_FUNCTOR_DISPATCH(FloorDivNode); + IR_EXPR_FUNCTOR_DISPATCH(FloorModNode); + IR_EXPR_FUNCTOR_DISPATCH(MinNode); + IR_EXPR_FUNCTOR_DISPATCH(MaxNode); IR_EXPR_FUNCTOR_DISPATCH(EQ); IR_EXPR_FUNCTOR_DISPATCH(NE); IR_EXPR_FUNCTOR_DISPATCH(LT); @@ -307,10 +307,10 @@ class TVM_DLL ExprVisitor : void VisitExpr_(const MulNode* op) override; void VisitExpr_(const DivNode* op) override; void VisitExpr_(const ModNode* op) override; - void VisitExpr_(const FloorDiv* op) override; - void VisitExpr_(const FloorMod* op) override; - void VisitExpr_(const Min* op) override; - void VisitExpr_(const Max* op) override; + void VisitExpr_(const FloorDivNode* op) override; + void VisitExpr_(const FloorModNode* op) override; + void VisitExpr_(const MinNode* op) override; + void VisitExpr_(const MaxNode* op) override; void VisitExpr_(const EQ* op) override; void VisitExpr_(const NE* op) override; void VisitExpr_(const LT* op) override; @@ -352,10 +352,10 @@ class TVM_DLL ExprMutator : Expr VisitExpr_(const MulNode* op) override; Expr VisitExpr_(const DivNode* op) override; Expr VisitExpr_(const ModNode* op) override; - Expr VisitExpr_(const FloorDiv* op) override; - Expr VisitExpr_(const FloorMod* op) override; - Expr VisitExpr_(const Min* op) override; - Expr VisitExpr_(const Max* op) override; + Expr VisitExpr_(const FloorDivNode* op) override; + Expr VisitExpr_(const FloorModNode* op) override; + Expr VisitExpr_(const MinNode* op) override; + Expr VisitExpr_(const MaxNode* op) override; Expr VisitExpr_(const EQ* op) override; Expr VisitExpr_(const NE* op) override; Expr VisitExpr_(const LT* op) override; diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 6d271f6b77ca..9d5bd255a6d1 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -139,10 +139,10 @@ REGISTER_MAKE(SubNode); REGISTER_MAKE(MulNode); REGISTER_MAKE(DivNode); REGISTER_MAKE(ModNode); -REGISTER_MAKE(FloorDiv); -REGISTER_MAKE(FloorMod); -REGISTER_MAKE(Min); -REGISTER_MAKE(Max); +REGISTER_MAKE(FloorDivNode); +REGISTER_MAKE(FloorModNode); +REGISTER_MAKE(MinNode); +REGISTER_MAKE(MaxNode); REGISTER_MAKE(EQ); REGISTER_MAKE(NE); REGISTER_MAKE(LT); diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 1ede2453198d..519b19c311d4 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -455,8 +455,8 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { Expr VisitExpr_(const MulNode* op) final; Expr VisitExpr_(const DivNode* op) final; Expr VisitExpr_(const ModNode* op) final; - Expr VisitExpr_(const FloorDiv* op) final; - Expr VisitExpr_(const FloorMod* op) final; + Expr VisitExpr_(const FloorDivNode* op) final; + Expr VisitExpr_(const FloorModNode* op) final; Expr VisitExpr_(const Reduce* op) final; private: @@ -787,7 +787,7 @@ VisitExpr_(const DivNode* op) { } Expr CanonicalSimplifier::Impl:: -VisitExpr_(const FloorDiv* op) { +VisitExpr_(const FloorDivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -795,7 +795,7 @@ VisitExpr_(const FloorDiv* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; // x / c1 @@ -838,7 +838,7 @@ VisitExpr_(const FloorDiv* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return FloorDiv::make(a, b); + return FloorDivNode::make(a, b); } } @@ -963,7 +963,7 @@ VisitExpr_(const ModNode* op) { } Expr CanonicalSimplifier::Impl:: -VisitExpr_(const FloorMod* op) { +VisitExpr_(const FloorModNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -972,7 +972,7 @@ VisitExpr_(const FloorMod* op) { Expr b = this->CanonicalMutate(op->b); // const folding - Expr const_res = TryConstFold(a, b); + Expr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; @@ -1018,7 +1018,7 @@ VisitExpr_(const FloorMod* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return FloorMod::make(a, b); + return FloorModNode::make(a, b); } } diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h index 36571078f0b1..aca26e85375a 100644 --- a/src/arithmetic/compute_expr.h +++ b/src/arithmetic/compute_expr.h @@ -102,12 +102,12 @@ inline Expr Compute(Expr a, Expr b) { } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return max(a, b); } template<> -inline Expr Compute(Expr a, Expr b) { +inline Expr Compute(Expr a, Expr b) { return min(a, b); } diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index dbbbdcfbecda..1ca01d75b7ee 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -199,7 +199,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -226,7 +226,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -244,7 +244,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); @@ -255,7 +255,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index 9a67a1b759c2..06e437e2ed82 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -215,7 +215,7 @@ class ConstIntBoundAnalyzer::Impl : } } - Entry VisitExpr_(const FloorDiv* op) final { + Entry VisitExpr_(const FloorDivNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); CHECK(!b.is_const(0)) << "floordiv by zero"; @@ -225,7 +225,7 @@ class ConstIntBoundAnalyzer::Impl : return BinaryOpBoundry(a, b, InfAwareFloorDiv); } - Entry VisitExpr_(const FloorMod* op) final { + Entry VisitExpr_(const FloorModNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); if (b.min_value > 0) { @@ -246,7 +246,7 @@ class ConstIntBoundAnalyzer::Impl : } } - Entry VisitExpr_(const Min* op) final { + Entry VisitExpr_(const MinNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); Entry ret; @@ -255,7 +255,7 @@ class ConstIntBoundAnalyzer::Impl : return ret; } - Entry VisitExpr_(const Max* op) final { + Entry VisitExpr_(const MaxNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); Entry ret; diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index b9a9a1ecb77e..742e24c332da 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -210,7 +210,7 @@ bool DetectClipBound( if (is_const_int(ret.coeff, 1)) { // var + shift >=0 -> var >= -shift if (p.min_value.defined()) { - p.min_value = ir::Max::make(p.min_value, -ret.base); + p.min_value = ir::MaxNode::make(p.min_value, -ret.base); } else { p.min_value = -ret.base; } @@ -219,7 +219,7 @@ bool DetectClipBound( if (is_const_int(ret.coeff, -1)) { // -var + shift >=0 -> var <= shift if (p.max_value.defined()) { - p.max_value = ir::Min::make(p.max_value, ret.base); + p.max_value = ir::MinNode::make(p.max_value, ret.base); } else { p.max_value = ret.base; } diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 86cb1bf622c3..8cb52271b833 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -256,7 +256,7 @@ inline IntervalSet Combine(Analyzer* analyzer, template<> -inline IntervalSet Combine(Analyzer* analyzer, +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -291,7 +291,7 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analyzer, +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -317,7 +317,7 @@ inline IntervalSet Combine(Analyzer* analyzer, } template<> -inline IntervalSet Combine(Analyzer* analzyer, +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -330,7 +330,7 @@ inline IntervalSet Combine(Analyzer* analzyer, } template<> -inline IntervalSet Combine(Analyzer* analzyer, +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { @@ -425,19 +425,19 @@ class IntervalSetEvaluator : return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const FloorDiv* op) final { + IntervalSet VisitExpr_(const FloorDivNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const FloorMod* op) final { + IntervalSet VisitExpr_(const FloorModNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Min* op) final { + IntervalSet VisitExpr_(const MinNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Max* op) final { + IntervalSet VisitExpr_(const MaxNode* op) final { return VisitBinaryExpr_(op); } diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 37fa30debc2d..09e9ee31ab44 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -196,7 +196,7 @@ class ModularSetAnalyzer::Impl : return Everything(); } - Entry VisitExpr_(const FloorDiv* op) final { + Entry VisitExpr_(const FloorDivNode* op) final { Entry b = VisitExpr(op->b); if (b.is_const()) { return DivByConst(op->a, b.base, true); @@ -204,13 +204,13 @@ class ModularSetAnalyzer::Impl : return Everything(); } - Entry VisitExpr_(const Min* op) final { + Entry VisitExpr_(const MinNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); return Union(a, b); } - Entry VisitExpr_(const Max* op) final { + Entry VisitExpr_(const MaxNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); return Union(a, b); diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index 7e20302ca8f8..6670326e8046 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -332,13 +332,13 @@ TVM_PATTERN_BINARY_OP_EX(operator%, ir::ModNode, DivAmbiguityError(a)); TVM_PATTERN_BINARY_OP(operator+, ir::AddNode); TVM_PATTERN_BINARY_OP(operator-, ir::SubNode); TVM_PATTERN_BINARY_OP(operator*, ir::MulNode); -TVM_PATTERN_BINARY_OP(min, ir::Min); -TVM_PATTERN_BINARY_OP(max, ir::Max); +TVM_PATTERN_BINARY_OP(min, ir::MinNode); +TVM_PATTERN_BINARY_OP(max, ir::MaxNode); TVM_PATTERN_BINARY_OP(div, ir::DivNode); TVM_PATTERN_BINARY_OP(truncdiv, ir::DivNode); TVM_PATTERN_BINARY_OP(truncmod, ir::ModNode); -TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv); -TVM_PATTERN_BINARY_OP(floormod, ir::FloorMod); +TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDivNode); +TVM_PATTERN_BINARY_OP(floormod, ir::FloorModNode); // logical expressions TVM_PATTERN_BINARY_OP(operator>, ir::GT); diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 9e81645e5a02..b8fec1fbc7bb 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -781,10 +781,10 @@ VisitExpr_(const ModNode* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const FloorDiv* op) { +VisitExpr_(const FloorDivNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1; @@ -925,10 +925,10 @@ VisitExpr_(const FloorDiv* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const FloorMod* op) { +VisitExpr_(const FloorModNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -995,10 +995,10 @@ VisitExpr_(const FloorMod* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Min* op) { +VisitExpr_(const MinNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1180,10 +1180,10 @@ VisitExpr_(const Min* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Max* op) { +VisitExpr_(const MaxNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression diff --git a/src/arithmetic/rewrite_simplify.h b/src/arithmetic/rewrite_simplify.h index 39f5cf4f9954..f3899d54e236 100644 --- a/src/arithmetic/rewrite_simplify.h +++ b/src/arithmetic/rewrite_simplify.h @@ -55,10 +55,10 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { Expr VisitExpr_(const MulNode* op) override; Expr VisitExpr_(const DivNode* op) override; Expr VisitExpr_(const ModNode* op) override; - Expr VisitExpr_(const FloorDiv* op) override; - Expr VisitExpr_(const FloorMod* op) override; - Expr VisitExpr_(const Min* op) override; - Expr VisitExpr_(const Max* op) override; + Expr VisitExpr_(const FloorDivNode* op) override; + Expr VisitExpr_(const FloorModNode* op) override; + Expr VisitExpr_(const MinNode* op) override; + Expr VisitExpr_(const MaxNode* op) override; Expr VisitExpr_(const EQ* op) override; Expr VisitExpr_(const NE* op) override; Expr VisitExpr_(const LT* op) override; diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 523a565a866d..7e6ce33b6da0 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -480,10 +480,10 @@ void CodeGenC::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*) void CodeGenC::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "%", os, this); } -void CodeGenC::VisitExpr_(const Min* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "min", os, this); } -void CodeGenC::VisitExpr_(const Max* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "max", os, this); } void CodeGenC::VisitExpr_(const EQ* op, std::ostream& os) { // NOLINT(*) diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index 345e817e065f..8e06989baa53 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -111,8 +111,8 @@ class CodeGenC : void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*) diff --git a/src/codegen/codegen_c_host.cc b/src/codegen/codegen_c_host.cc index 506618200e29..ac5393af9b28 100644 --- a/src/codegen/codegen_c_host.cc +++ b/src/codegen/codegen_c_host.cc @@ -254,11 +254,11 @@ void CodeGenCHost::VisitStmt_(const AssertStmt *op) { // NOLINT(*) this->PrintStmt(op->body); } -void CodeGenCHost::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*) PrintTernaryCondExpr(op, "<", os); } -void CodeGenCHost::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*) PrintTernaryCondExpr(op, ">", os); } diff --git a/src/codegen/codegen_c_host.h b/src/codegen/codegen_c_host.h index 44f838536627..7d5ce58f1162 100644 --- a/src/codegen/codegen_c_host.h +++ b/src/codegen/codegen_c_host.h @@ -46,8 +46,8 @@ class CodeGenCHost final : public CodeGenC { void VisitExpr_(const Call *op, std::ostream& os) final; // NOLINT(*) // overload min and max to use the ternary operator, so we don't rely on the // standard library implementations - void VisitExpr_(const Min *op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const Max *op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const MinNode *op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const MaxNode *op, std::ostream& os) final; // NOLINT(*) void VisitStmt_(const AssertStmt *op) final; // NOLINT(*) diff --git a/src/codegen/codegen_vhls.cc b/src/codegen/codegen_vhls.cc index d12e54da1958..e7231a16c280 100644 --- a/src/codegen/codegen_vhls.cc +++ b/src/codegen/codegen_vhls.cc @@ -98,7 +98,7 @@ inline void PrintBinaryExpr(const T* op, os << ')'; } -void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*) +void CodeGenVivadoHLS::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*) const char *opstr = "std::min"; if (op->dtype.is_float()) { switch (op->dtype.bits()) { @@ -112,7 +112,7 @@ void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT( PrintBinaryExpr(op, opstr, os, this); } -void CodeGenVivadoHLS::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*) +void CodeGenVivadoHLS::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*) const char *opstr = "std::max"; if (op->dtype.is_float()) { switch (op->dtype.bits()) { diff --git a/src/codegen/codegen_vhls.h b/src/codegen/codegen_vhls.h index e678edb05198..e406cb56e40e 100644 --- a/src/codegen/codegen_vhls.h +++ b/src/codegen/codegen_vhls.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -38,8 +38,8 @@ class CodeGenVivadoHLS final : public CodeGenC { void PrintType(DataType t, std::ostream& os); void AddFunction(LoweredFunc f); void PreFunctionBody(LoweredFunc f); - void VisitExpr_(const Min *op, std::ostream& os); - void VisitExpr_(const Max *op, std::ostream& os); + void VisitExpr_(const MinNode *op, std::ostream& os); + void VisitExpr_(const MaxNode *op, std::ostream& os); }; } // namespace codegen diff --git a/src/codegen/llvm/codegen_cpu.cc b/src/codegen/llvm/codegen_cpu.cc index 9f1a2926f002..1949c8f69782 100644 --- a/src/codegen/llvm/codegen_cpu.cc +++ b/src/codegen/llvm/codegen_cpu.cc @@ -936,8 +936,8 @@ void CodeGenCPU::VisitStmt_(const For* op) { op->body); } else { Expr step = (op->extent + num_task - make_const(t, 1)) / num_task; - Expr begin = Min::make(task_id * step, op->extent); - Expr end = Min::make((task_id + make_const(t, 1)) * step, op->extent); + Expr begin = MinNode::make(task_id * step, op->extent); + Expr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent); CreateSerialFor(MakeValue(begin), MakeValue(end), ConstInt32(1), diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 6ea9d42eaa8d..5bcd66ba55fe 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -891,13 +891,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) { } } -llvm::Value* CodeGenLLVM::VisitExpr_(const Min* op) { +llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); } -llvm::Value* CodeGenLLVM::VisitExpr_(const Max* op) { +llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index f1f48df78315..e2d4a4209f73 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -114,8 +114,8 @@ class CodeGenLLVM : llvm::Value* VisitExpr_(const MulNode* op) override; llvm::Value* VisitExpr_(const DivNode* op) override; llvm::Value* VisitExpr_(const ModNode* op) override; - llvm::Value* VisitExpr_(const Min* op) override; - llvm::Value* VisitExpr_(const Max* op) override; + llvm::Value* VisitExpr_(const MinNode* op) override; + llvm::Value* VisitExpr_(const MaxNode* op) override; llvm::Value* VisitExpr_(const LT* op) override; llvm::Value* VisitExpr_(const LE* op) override; llvm::Value* VisitExpr_(const GT* op) override; diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index cfc175ee0554..88f8e89647be 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -173,13 +173,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const ModNode* op) { return builder_->Mod(MakeValue(op->a), MakeValue(op->b)); } -spirv::Value CodeGenSPIRV::VisitExpr_(const Min* op) { +spirv::Value CodeGenSPIRV::VisitExpr_(const MinNode* op) { spirv::Value a = MakeValue(op->a); spirv::Value b = MakeValue(op->b); return builder_->Select(builder_->LT(a, b), a, b); } -spirv::Value CodeGenSPIRV::VisitExpr_(const Max* op) { +spirv::Value CodeGenSPIRV::VisitExpr_(const MaxNode* op) { spirv::Value a = MakeValue(op->a); spirv::Value b = MakeValue(op->b); return builder_->Select(builder_->GT(a, b), a, b); diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index 6839ed2a4635..07b305c1b812 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -73,8 +73,8 @@ class CodeGenSPIRV: spirv::Value VisitExpr_(const MulNode* op) override; spirv::Value VisitExpr_(const DivNode* op) override; spirv::Value VisitExpr_(const ModNode* op) override; - spirv::Value VisitExpr_(const Min* op) override; - spirv::Value VisitExpr_(const Max* op) override; + spirv::Value VisitExpr_(const MinNode* op) override; + spirv::Value VisitExpr_(const MaxNode* op) override; spirv::Value VisitExpr_(const LT* op) override; spirv::Value VisitExpr_(const LE* op) override; spirv::Value VisitExpr_(const GT* op) override; diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index 21f7df811fea..d47c0a41f9ea 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -320,7 +320,7 @@ void CodeGenStackVM::VisitExpr_(const ModNode* op) { PushBinary(StackVM::MOD_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Min* op) { +void CodeGenStackVM::VisitExpr_(const MinNode* op) { this->Push(op->a); this->Push(op->b); this->PushOp(StackVM::PUSH_VALUE, -1); @@ -329,7 +329,7 @@ void CodeGenStackVM::VisitExpr_(const Min* op) { this->PushOp(StackVM::SELECT); } -void CodeGenStackVM::VisitExpr_(const Max* op) { +void CodeGenStackVM::VisitExpr_(const MaxNode* op) { this->Push(op->a); this->Push(op->b); this->PushOp(StackVM::PUSH_VALUE, 0); diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index 283a31dfd754..41ba2671756d 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -120,8 +120,8 @@ class CodeGenStackVM void VisitExpr_(const MulNode* op) final; void VisitExpr_(const DivNode* op) final; void VisitExpr_(const ModNode* op) final; - void VisitExpr_(const Min* op) final; - void VisitExpr_(const Max* op) final; + void VisitExpr_(const MinNode* op) final; + void VisitExpr_(const MaxNode* op) final; void VisitExpr_(const EQ* op) final; void VisitExpr_(const NE* op) final; void VisitExpr_(const LT* op) final; diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 484d589dea7e..6f066b2ca50a 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -158,7 +158,7 @@ void CodeGenHybrid::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT PrintBinaryExpr(op, "/", os, this); } -void CodeGenHybrid::VisitExpr_(const FloorDiv* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const FloorDivNode* op, std::ostream& os) { // NOLINT(*) if (op->dtype.is_int()) PrintBinaryExpr(op, "//", os, this); else @@ -169,13 +169,13 @@ void CodeGenHybrid::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT PrintBinaryExpr(op, "%", os, this); } -void CodeGenHybrid::VisitExpr_(const FloorMod* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const FloorModNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "%", os, this); } -void CodeGenHybrid::VisitExpr_(const Min* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "min", os, this); } -void CodeGenHybrid::VisitExpr_(const Max* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "max", os, this); } void CodeGenHybrid::VisitExpr_(const EQ* op, std::ostream& os) { // NOLINT(*) diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 63bf1f74eb9d..5a556f7b5d98 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -99,10 +99,10 @@ class CodeGenHybrid : void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloorDiv* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloorMod* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Min* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const Max* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const EQ* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const NE* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LT* op, std::ostream& os) override; // NOLINT(*) diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index bf18085c2a61..f0b53217de1d 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -87,10 +87,10 @@ class AttrFunctor { virtual R VisitAttr_(const ir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::DivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::FloorDiv* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::FloorMod* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::FloorDivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::FloorModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::MinNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::MaxNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::GE* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::GT* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::LT* op, Args... args) ATTR_FUNCTOR_DEFAULT; @@ -122,10 +122,10 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(MulNode); ATTR_FUNCTOR_DISPATCH(DivNode); ATTR_FUNCTOR_DISPATCH(ModNode); - ATTR_FUNCTOR_DISPATCH(FloorDiv); - ATTR_FUNCTOR_DISPATCH(FloorMod); - ATTR_FUNCTOR_DISPATCH(Min); - ATTR_FUNCTOR_DISPATCH(Max); + ATTR_FUNCTOR_DISPATCH(FloorDivNode); + ATTR_FUNCTOR_DISPATCH(FloorModNode); + ATTR_FUNCTOR_DISPATCH(MinNode); + ATTR_FUNCTOR_DISPATCH(MaxNode); ATTR_FUNCTOR_DISPATCH(GE); ATTR_FUNCTOR_DISPATCH(GT); ATTR_FUNCTOR_DISPATCH(LE); @@ -165,10 +165,10 @@ class AttrsEqualHandler : bool VisitAttr_(const ir::MulNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::DivNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::ModNode* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::FloorDiv* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::FloorMod* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::Min* lhs, const ObjectRef& other) final; - bool VisitAttr_(const ir::Max* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::FloorDivNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::FloorModNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::MinNode* lhs, const ObjectRef& other) final; + bool VisitAttr_(const ir::MaxNode* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::GE* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::GT* lhs, const ObjectRef& other) final; bool VisitAttr_(const ir::LT* lhs, const ObjectRef& other) final; @@ -208,10 +208,10 @@ class AttrsHashHandler : size_t VisitAttr_(const ir::MulNode* op) final; size_t VisitAttr_(const ir::DivNode* op) final; size_t VisitAttr_(const ir::ModNode* op) final; - size_t VisitAttr_(const ir::FloorDiv* op) final; - size_t VisitAttr_(const ir::FloorMod* op) final; - size_t VisitAttr_(const ir::Min* op) final; - size_t VisitAttr_(const ir::Max* op) final; + size_t VisitAttr_(const ir::FloorDivNode* op) final; + size_t VisitAttr_(const ir::FloorModNode* op) final; + size_t VisitAttr_(const ir::MinNode* op) final; + size_t VisitAttr_(const ir::MaxNode* op) final; size_t VisitAttr_(const ir::GE* op) final; size_t VisitAttr_(const ir::GT* op) final; size_t VisitAttr_(const ir::LE* op) final; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index d17b2d905430..6fc2100870ad 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -156,10 +156,10 @@ TVM_DEFINE_ATTRS_BINOP_EQUAL(SubNode); TVM_DEFINE_ATTRS_BINOP_EQUAL(MulNode); TVM_DEFINE_ATTRS_BINOP_EQUAL(DivNode); TVM_DEFINE_ATTRS_BINOP_EQUAL(ModNode); -TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDiv); -TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorMod); -TVM_DEFINE_ATTRS_BINOP_EQUAL(Max); -TVM_DEFINE_ATTRS_BINOP_EQUAL(Min); +TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDivNode); +TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorModNode); +TVM_DEFINE_ATTRS_BINOP_EQUAL(MaxNode); +TVM_DEFINE_ATTRS_BINOP_EQUAL(MinNode); TVM_DEFINE_ATTRS_BINOP_EQUAL(GE); TVM_DEFINE_ATTRS_BINOP_EQUAL(GT); TVM_DEFINE_ATTRS_BINOP_EQUAL(LE); @@ -270,10 +270,10 @@ TVM_DEFINE_ATTRS_BINOP_HASH(SubNode); TVM_DEFINE_ATTRS_BINOP_HASH(MulNode); TVM_DEFINE_ATTRS_BINOP_HASH(DivNode); TVM_DEFINE_ATTRS_BINOP_HASH(ModNode); -TVM_DEFINE_ATTRS_BINOP_HASH(FloorDiv); -TVM_DEFINE_ATTRS_BINOP_HASH(FloorMod); -TVM_DEFINE_ATTRS_BINOP_HASH(Max); -TVM_DEFINE_ATTRS_BINOP_HASH(Min); +TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode); +TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode); +TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode); +TVM_DEFINE_ATTRS_BINOP_HASH(MinNode); TVM_DEFINE_ATTRS_BINOP_HASH(GE); TVM_DEFINE_ATTRS_BINOP_HASH(GT); TVM_DEFINE_ATTRS_BINOP_HASH(LE); diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 33c6a707fc79..d8482be007c8 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -31,8 +31,8 @@ namespace tvm { // TODO(tqchen): change to floormod/div -using IndexMod = ir::FloorMod; -using IndexDiv = ir::FloorDiv; +using IndexMod = ir::FloorModNode; +using IndexDiv = ir::FloorDivNode; Array SimplifyArray(Array array) { for (size_t i = 0; i < array.size(); ++i) { diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 7c8a791e7c24..cdab989aae28 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -278,18 +278,18 @@ Expr floordiv(Expr a, Expr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::FloorDiv::make(a, b); + return ir::FloorDivNode::make(a, b); } Expr floormod(Expr a, Expr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::FloorMod::make(a, b); + return ir::FloorModNode::make(a, b); } Expr min(Expr a, Expr b) { @@ -301,9 +301,9 @@ Expr min(Expr a, Expr b) { if (is_pos_inf(b)) return a; if (is_neg_inf(b)) return b; BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Min::make(a, b); + return ir::MinNode::make(a, b); } Expr max(Expr a, Expr b) { @@ -315,9 +315,9 @@ Expr max(Expr a, Expr b) { if (is_pos_inf(b)) return b; if (is_neg_inf(b)) return a; BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Max::make(a, b); + return ir::MaxNode::make(a, b); } Expr if_then_else(Expr cond, Expr true_value, Expr false_value) { @@ -557,7 +557,7 @@ Expr any(Expr source, Array rdom) { Expr max(Expr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::Max::make(x, y); + Expr result = ir::MaxNode::make(x, y); Expr identity_element = min_value(source.dtype()); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); @@ -566,7 +566,7 @@ Expr max(Expr source, Array rdom) { Expr min(Expr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::Min::make(x, y); + Expr result = ir::MinNode::make(x, y); Expr identity_element = max_value(source.dtype()); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); diff --git a/src/lang/ir.cc b/src/lang/ir.cc index cda2d3b49997..3b298769e1f7 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -645,16 +645,16 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << "min("; p->Print(op->a); p->stream << ", "; p->Print(op->b); p->stream << ")"; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << "max("; p->Print(op->a); p->stream << ", "; @@ -711,14 +711,14 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << "floordiv(" << op->a << ", " << op->b << ")"; }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << "floormod(" << op->a << ", " << op->b << ")"; }); @@ -1162,10 +1162,10 @@ TVM_REGISTER_NODE_TYPE(SubNode); TVM_REGISTER_NODE_TYPE(MulNode); TVM_REGISTER_NODE_TYPE(DivNode); TVM_REGISTER_NODE_TYPE(ModNode); -TVM_REGISTER_NODE_TYPE(FloorDiv); -TVM_REGISTER_NODE_TYPE(FloorMod); -TVM_REGISTER_NODE_TYPE(Min); -TVM_REGISTER_NODE_TYPE(Max); +TVM_REGISTER_NODE_TYPE(FloorDivNode); +TVM_REGISTER_NODE_TYPE(FloorModNode); +TVM_REGISTER_NODE_TYPE(MinNode); +TVM_REGISTER_NODE_TYPE(MaxNode); TVM_REGISTER_NODE_TYPE(EQ); TVM_REGISTER_NODE_TYPE(NE); TVM_REGISTER_NODE_TYPE(LT); diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index 6ccf393e6a99..f93ed80575ac 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -124,7 +124,7 @@ class CopyIntrinInjector : public StmtMutator { DataType t = loop_vars[i].dtype(); Expr svalue = src_shape[i]; if (min_value.defined()) { - Expr pbefore = Simplify(Max::make(min_value, make_zero(t))); + Expr pbefore = Simplify(MaxNode::make(min_value, make_zero(t))); src_elem_offset = src_elem_offset + pbefore * load_strides[i]; svalue = svalue - pbefore; pad_before.push_back(pbefore); @@ -132,7 +132,7 @@ class CopyIntrinInjector : public StmtMutator { pad_before.push_back(make_zero(t)); } if (max_value.defined()) { - Expr pafter = Simplify(Max::make(loops[i]->extent - max_value - make_const(t, 1), + Expr pafter = Simplify(MaxNode::make(loops[i]->extent - max_value - make_const(t, 1), make_zero(t))); svalue = svalue - pafter; pad_after.push_back(pafter); diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index 62cd9ab629cf..2339563c80eb 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -303,10 +303,10 @@ class IRDeepCompare : DEFINE_BIOP_EXPR_CMP_(MulNode) DEFINE_BIOP_EXPR_CMP_(DivNode) DEFINE_BIOP_EXPR_CMP_(ModNode) - DEFINE_BIOP_EXPR_CMP_(FloorDiv) - DEFINE_BIOP_EXPR_CMP_(FloorMod) - DEFINE_BIOP_EXPR_CMP_(Min) - DEFINE_BIOP_EXPR_CMP_(Max) + DEFINE_BIOP_EXPR_CMP_(FloorDivNode) + DEFINE_BIOP_EXPR_CMP_(FloorModNode) + DEFINE_BIOP_EXPR_CMP_(MinNode) + DEFINE_BIOP_EXPR_CMP_(MaxNode) DEFINE_BIOP_EXPR_CMP_(EQ) DEFINE_BIOP_EXPR_CMP_(NE) DEFINE_BIOP_EXPR_CMP_(LT) diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index 27305917bd81..7e40aff8dfd8 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -246,10 +246,10 @@ DEFINE_BINOP_VISIT_(SubNode); DEFINE_BINOP_VISIT_(MulNode); DEFINE_BINOP_VISIT_(DivNode); DEFINE_BINOP_VISIT_(ModNode); -DEFINE_BINOP_VISIT_(FloorDiv); -DEFINE_BINOP_VISIT_(FloorMod); -DEFINE_BINOP_VISIT_(Min); -DEFINE_BINOP_VISIT_(Max); +DEFINE_BINOP_VISIT_(FloorDivNode); +DEFINE_BINOP_VISIT_(FloorModNode); +DEFINE_BINOP_VISIT_(MinNode); +DEFINE_BINOP_VISIT_(MaxNode); DEFINE_BINOP_VISIT_(EQ); DEFINE_BINOP_VISIT_(NE); DEFINE_BINOP_VISIT_(LT); @@ -661,10 +661,10 @@ DEFINE_BIOP_EXPR_MUTATE_(SubNode); DEFINE_BIOP_EXPR_MUTATE_(MulNode); DEFINE_BIOP_EXPR_MUTATE_(DivNode); DEFINE_BIOP_EXPR_MUTATE_(ModNode); -DEFINE_BIOP_EXPR_MUTATE_(FloorDiv); -DEFINE_BIOP_EXPR_MUTATE_(FloorMod); -DEFINE_BIOP_EXPR_MUTATE_(Min); -DEFINE_BIOP_EXPR_MUTATE_(Max); +DEFINE_BIOP_EXPR_MUTATE_(FloorDivNode); +DEFINE_BIOP_EXPR_MUTATE_(FloorModNode); +DEFINE_BIOP_EXPR_MUTATE_(MinNode); +DEFINE_BIOP_EXPR_MUTATE_(MaxNode); DEFINE_BIOP_EXPR_MUTATE_(EQ); DEFINE_BIOP_EXPR_MUTATE_(NE); DEFINE_BIOP_EXPR_MUTATE_(LT); diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index aa8ebe1eb19b..6ea2959757fc 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -506,7 +506,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop"; - body_begin = Max::make(body_begin, min); + body_begin = MaxNode::make(body_begin, min); // stop recursing on this interval if we can't prove it has non-negative length pre_stmt_recurse = false; } @@ -532,7 +532,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, if (!analyzer_.CanProve(cond)) { LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop"; - post_doubt_begin = Min::make(post_doubt_begin, max+1); + post_doubt_begin = MinNode::make(post_doubt_begin, max+1); // stop recursing on this interval if we can't prove it has non-negative length post_stmt_recurse = false; } diff --git a/src/pass/lower_custom_datatypes.cc b/src/pass/lower_custom_datatypes.cc index 3c0439a49c15..e9c4632fe08b 100644 --- a/src/pass/lower_custom_datatypes.cc +++ b/src/pass/lower_custom_datatypes.cc @@ -116,8 +116,8 @@ class CustomDatatypesLowerer : public StmtExprMutator { DEFINE_MUTATE__(Mul, MulNode); DEFINE_MUTATE__(Div, DivNode); DEFINE_MUTATE__(Mod, ModNode); - DEFINE_MUTATE__(Min, Min); - DEFINE_MUTATE__(Max, Max); + DEFINE_MUTATE__(Min, MinNode); + DEFINE_MUTATE__(Max, MaxNode); DEFINE_MUTATE__(EQ, EQ); DEFINE_MUTATE__(NE, NE); DEFINE_MUTATE__(LT, LT); diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index 68b7253c401f..9ef5ad92cac3 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -73,10 +73,10 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // We use floordiv for integer analysis, // but will need to lower them to native truncdiv instructions - Expr VisitExpr_(const FloorDiv* op) final { + Expr VisitExpr_(const FloorDivNode* op) final { auto e = GetRef(op); Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); + op = ret.as(); if (op == nullptr) return ret; int shift; const DataType& dtype = op->dtype; @@ -120,9 +120,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } - Expr VisitExpr_(const FloorMod* op) final { + Expr VisitExpr_(const FloorModNode* op) final { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); + op = ret.as(); if (op == nullptr) return ret; // Lower floordiv to native truncdiv. int shift; @@ -170,7 +170,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } } - Expr VisitExpr_(const Max* op) final { + Expr VisitExpr_(const MaxNode* op) final { using namespace arith; PVar x, y; PVar c; diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc index 8886222d2dca..b1d840afbfe9 100644 --- a/src/pass/rewrite_unsafe_select.cc +++ b/src/pass/rewrite_unsafe_select.cc @@ -62,10 +62,10 @@ class UnsafeExprDetector : public ExprFunctor { bool VisitExpr_(const MulNode* op) final { return BinaryOp(op); } bool VisitExpr_(const DivNode* op) final { return BinaryOp(op); } bool VisitExpr_(const ModNode* op) final { return BinaryOp(op); } - bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); } - bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); } - bool VisitExpr_(const Min* op) final { return BinaryOp(op); } - bool VisitExpr_(const Max* op) final { return BinaryOp(op); } + bool VisitExpr_(const FloorDivNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const FloorModNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const MinNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const MaxNode* op) final { return BinaryOp(op); } bool VisitExpr_(const EQ* op) final { return BinaryOp(op); } bool VisitExpr_(const NE* op) final { return BinaryOp(op); } bool VisitExpr_(const LT* op) final { return BinaryOp(op); } diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index 78066ef1b6ef..75e036caa537 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -147,16 +147,16 @@ class Vectorizer : public StmtExprMutator { Expr VisitExpr_(const ModNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const FloorDiv* op) final { + Expr VisitExpr_(const FloorDivNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const FloorMod* op) final { + Expr VisitExpr_(const FloorModNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const Min* op) final { + Expr VisitExpr_(const MinNode* op) final { return BinaryVec(op); } - Expr VisitExpr_(const Max* op) final { + Expr VisitExpr_(const MaxNode* op) final { return BinaryVec(op); } Expr VisitExpr_(const EQ* op) final { diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index debfb36f936b..4b6915f7de93 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -38,7 +38,7 @@ TEST(ExprNodeRef, Basic) { using namespace tvm; Var x("x"); Expr z = max(x + 1 + 2, 100); - const ir::Max* op = z.as(); + const ir::MaxNode* op = z.as(); CHECK(GetRef(op).same_as(z)); } diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h index c4cda6a20625..f58f73f8d2a7 100644 --- a/topi/include/topi/nn/pooling.h +++ b/topi/include/topi/nn/pooling.h @@ -155,11 +155,11 @@ inline Tensor pool_impl(const Tensor& x, } else { Expr h_start = output[height_axis] * stride_height - pad_top; Expr w_start = output[width_axis] * stride_width - pad_left; - Expr h_end = ir::Min::make(h_start + kernel_height, height); - Expr w_end = ir::Min::make(w_start + kernel_width, width); - h_start = ir::Max::make(h_start, make_const(DataType::DataType::Int(32), 0)); - w_start = ir::Max::make(w_start, make_const(DataType::DataType::Int(32), 0)); - Expr divide_factor = ir::Max::make((h_end - h_start) * (w_end - w_start), + Expr h_end = ir::MinNode::make(h_start + kernel_height, height); + Expr w_end = ir::MinNode::make(w_start + kernel_width, width); + h_start = ir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0)); + w_start = ir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0)); + Expr divide_factor = ir::MaxNode::make((h_end - h_start) * (w_end - w_start), make_const(DataType::DataType::Int(32), 1)); return div(pool_sum(indices), divide_factor); } @@ -308,12 +308,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, } else { Expr h_start = out_idx[height_axis] * stride_height - pad_top; Expr w_start = out_idx[width_axis] * stride_width - pad_left; - Expr h_end = ir::Min::make(h_start + kernel_height, height); - Expr w_end = ir::Min::make(w_start + kernel_width, width); - h_start = ir::Max::make(h_start, make_const(DataType::Int(32), 0)); - w_start = ir::Max::make(w_start, make_const(DataType::Int(32), 0)); + Expr h_end = ir::MinNode::make(h_start + kernel_height, height); + Expr w_end = ir::MinNode::make(w_start + kernel_width, width); + h_start = ir::MaxNode::make(h_start, make_const(DataType::Int(32), 0)); + w_start = ir::MaxNode::make(w_start, make_const(DataType::Int(32), 0)); divide_factor = - ir::Max::make((h_end - h_start) * (w_end - w_start), + ir::MaxNode::make((h_end - h_start) * (w_end - w_start), make_const(DataType::Int(32), 1)); } return tvm::sum(tvm::if_then_else( @@ -729,12 +729,12 @@ inline Tensor pool_impl_nd(const Tensor& x, for (int i = 0; i < k_size; i++) { int ii = axis[i]; start[i] = output[ii] * stride[i] - pad_head[i]; - end[i] = ir::Min::make(start[i] + kernel[i], x->shape[ii]); - start[i] = ir::Max::make(start[i], make_const(DataType::Int(32), 0)); + end[i] = ir::MinNode::make(start[i] + kernel[i], x->shape[ii]); + start[i] = ir::MaxNode::make(start[i], make_const(DataType::Int(32), 0)); kernel_size *= (end[i] - start[i]); } - Expr divide_factor = ir::Max::make(kernel_size, make_const(DataType::Int(32), 1)); + Expr divide_factor = ir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1)); return div(pool_sum(indices), divide_factor); } }, "tensor", kElementWise); From 6d52a308e525869872b3b2f5704dcf70f16751e8 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 7 Jan 2020 15:12:52 -0800 Subject: [PATCH 04/11] [REFACTOR][IR] EQ/NE/LT/LE/GT/GE/Select -> EQNode/NENode etc. --- include/tvm/ir.h | 28 ++++---- include/tvm/ir_functor_ext.h | 80 +++++++++++----------- src/api/api_ir.cc | 22 +++--- src/arithmetic/bound_deducer.cc | 18 ++--- src/arithmetic/canonical_simplify.cc | 2 +- src/arithmetic/const_fold.h | 18 ++--- src/arithmetic/const_int_bound.cc | 2 +- src/arithmetic/detect_linear_equation.cc | 10 +-- src/arithmetic/int_set.cc | 48 ++++++------- src/arithmetic/ir_mutator_with_analyzer.cc | 10 +-- src/arithmetic/ir_mutator_with_analyzer.h | 2 +- src/arithmetic/modular_set.cc | 2 +- src/arithmetic/pattern_match.h | 24 +++---- src/arithmetic/rewrite_simplify.cc | 48 ++++++------- src/arithmetic/rewrite_simplify.h | 20 +++--- src/codegen/codegen_c.cc | 20 +++--- src/codegen/codegen_c.h | 20 +++--- src/codegen/codegen_opencl.cc | 2 +- src/codegen/codegen_opencl.h | 2 +- src/codegen/llvm/codegen_llvm.cc | 16 ++--- src/codegen/llvm/codegen_llvm.h | 20 +++--- src/codegen/llvm/intrin_rule_llvm.cc | 2 +- src/codegen/spirv/codegen_spirv.cc | 20 +++--- src/codegen/spirv/codegen_spirv.h | 20 +++--- src/codegen/stackvm/codegen_stackvm.cc | 20 +++--- src/codegen/stackvm/codegen_stackvm.h | 20 +++--- src/contrib/hybrid/codegen_hybrid.cc | 20 +++--- src/contrib/hybrid/codegen_hybrid.h | 20 +++--- src/lang/attr_functor.h | 80 +++++++++++----------- src/lang/attrs.cc | 52 +++++++------- src/lang/expr_operator.cc | 42 ++++++------ src/lang/ir.cc | 76 ++++++++++---------- src/op/compute_op.cc | 2 +- src/pass/arg_binder.cc | 6 +- src/pass/bound_checker.cc | 4 +- src/pass/ir_deep_compare.cc | 24 +++---- src/pass/ir_functor.cc | 44 ++++++------ src/pass/loop_partition.cc | 24 +++---- src/pass/lower_custom_datatypes.cc | 12 ++-- src/pass/lower_intrin.cc | 12 ++-- src/pass/lower_thread_allreduce.cc | 2 +- src/pass/make_api.cc | 4 +- src/pass/rewrite_unsafe_select.cc | 24 +++---- src/pass/vectorize_loop.cc | 20 +++--- src/schedule/schedule_dataflow_rewrite.cc | 8 +-- tests/cpp/pattern_match_test.cc | 10 +-- topi/include/topi/elemwise.h | 4 +- topi/include/topi/nn.h | 8 +-- topi/include/topi/nn/pooling.h | 20 +++--- topi/include/topi/reduction.h | 8 +-- topi/include/topi/transform.h | 6 +- 51 files changed, 514 insertions(+), 514 deletions(-) diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 7b1127a047d9..42acb32885d6 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -235,43 +235,43 @@ class CmpOpNode : public ExprNode { }; /*! \brief a == b */ -class EQ : public CmpOpNode { +class EQNode : public CmpOpNode { public: static constexpr const char* _type_key = "EQ"; }; /*! \brief a != b */ -class NE : public CmpOpNode { +class NENode : public CmpOpNode { public: static constexpr const char* _type_key = "NE"; }; /*! \brief a < b */ -class LT : public CmpOpNode { +class LTNode : public CmpOpNode { public: static constexpr const char* _type_key = "LT"; }; /*! \brief a <= b */ -struct LE : public CmpOpNode { +struct LENode : public CmpOpNode { public: static constexpr const char* _type_key = "LE"; }; /*! \brief a > b */ -class GT : public CmpOpNode { +class GTNode : public CmpOpNode { public: static constexpr const char* _type_key = "GT"; }; /*! \brief a >= b */ -class GE : public CmpOpNode { +class GENode : public CmpOpNode { public: static constexpr const char* _type_key = "GE"; }; /*! \brief a && b */ -class And : public ExprNode { +class AndNode : public ExprNode { public: /*! \brief The left operand. */ Expr a; @@ -287,11 +287,11 @@ class And : public ExprNode { TVM_DLL static Expr make(Expr a, Expr b); static constexpr const char* _type_key = "And"; - TVM_DECLARE_FINAL_OBJECT_INFO(And, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, ExprNode); }; /*! \brief a || b */ -class Or : public ExprNode { +class OrNode : public ExprNode { public: /*! \brief The left operand. */ Expr a; @@ -307,11 +307,11 @@ class Or : public ExprNode { TVM_DLL static Expr make(Expr a, Expr b); static constexpr const char* _type_key = "Or"; - TVM_DECLARE_FINAL_OBJECT_INFO(Or, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, ExprNode); }; /*! \brief !a */ -class Not : public ExprNode { +class NotNode : public ExprNode { public: /*! \brief The input operand. */ Expr a; @@ -324,7 +324,7 @@ class Not : public ExprNode { TVM_DLL static Expr make(Expr a); static constexpr const char* _type_key = "Not"; - TVM_DECLARE_FINAL_OBJECT_INFO(Not, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, ExprNode); }; /*! @@ -334,7 +334,7 @@ class Not : public ExprNode { * Do not use it to guard against out of bound access, * please use if_then_else instead. */ -class Select : public ExprNode { +class SelectNode : public ExprNode { public: /*! \brief The condition */ Expr condition; @@ -353,7 +353,7 @@ class Select : public ExprNode { TVM_DLL static Expr make(Expr condition, Expr true_value, Expr false_value); static constexpr const char* _type_key = "Select"; - TVM_DECLARE_FINAL_OBJECT_INFO(Select, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, ExprNode); }; /*! diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index eb0601be4075..a6136105258f 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -145,18 +145,18 @@ class ExprFunctor { virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const LE* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const GT* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const GE* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const EQNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const NENode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const LENode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GENode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const AndNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const OrNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const CastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const NotNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const SelectNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const Shuffle* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -187,18 +187,18 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(FloorModNode); IR_EXPR_FUNCTOR_DISPATCH(MinNode); IR_EXPR_FUNCTOR_DISPATCH(MaxNode); - IR_EXPR_FUNCTOR_DISPATCH(EQ); - IR_EXPR_FUNCTOR_DISPATCH(NE); - IR_EXPR_FUNCTOR_DISPATCH(LT); - IR_EXPR_FUNCTOR_DISPATCH(LE); - IR_EXPR_FUNCTOR_DISPATCH(GT); - IR_EXPR_FUNCTOR_DISPATCH(GE); - IR_EXPR_FUNCTOR_DISPATCH(And); - IR_EXPR_FUNCTOR_DISPATCH(Or); + IR_EXPR_FUNCTOR_DISPATCH(EQNode); + IR_EXPR_FUNCTOR_DISPATCH(NENode); + IR_EXPR_FUNCTOR_DISPATCH(LTNode); + IR_EXPR_FUNCTOR_DISPATCH(LENode); + IR_EXPR_FUNCTOR_DISPATCH(GTNode); + IR_EXPR_FUNCTOR_DISPATCH(GENode); + IR_EXPR_FUNCTOR_DISPATCH(AndNode); + IR_EXPR_FUNCTOR_DISPATCH(OrNode); IR_EXPR_FUNCTOR_DISPATCH(Reduce); IR_EXPR_FUNCTOR_DISPATCH(CastNode); - IR_EXPR_FUNCTOR_DISPATCH(Not); - IR_EXPR_FUNCTOR_DISPATCH(Select); + IR_EXPR_FUNCTOR_DISPATCH(NotNode); + IR_EXPR_FUNCTOR_DISPATCH(SelectNode); IR_EXPR_FUNCTOR_DISPATCH(Ramp); IR_EXPR_FUNCTOR_DISPATCH(Shuffle); IR_EXPR_FUNCTOR_DISPATCH(Broadcast); @@ -311,18 +311,18 @@ class TVM_DLL ExprVisitor : void VisitExpr_(const FloorModNode* op) override; void VisitExpr_(const MinNode* op) override; void VisitExpr_(const MaxNode* op) override; - void VisitExpr_(const EQ* op) override; - void VisitExpr_(const NE* op) override; - void VisitExpr_(const LT* op) override; - void VisitExpr_(const LE* op) override; - void VisitExpr_(const GT* op) override; - void VisitExpr_(const GE* op) override; - void VisitExpr_(const And* op) override; - void VisitExpr_(const Or* op) override; + void VisitExpr_(const EQNode* op) override; + void VisitExpr_(const NENode* op) override; + void VisitExpr_(const LTNode* op) override; + void VisitExpr_(const LENode* op) override; + void VisitExpr_(const GTNode* op) override; + void VisitExpr_(const GENode* op) override; + void VisitExpr_(const AndNode* op) override; + void VisitExpr_(const OrNode* op) override; void VisitExpr_(const Reduce* op) override; void VisitExpr_(const CastNode* op) override; - void VisitExpr_(const Not* op) override; - void VisitExpr_(const Select* op) override; + void VisitExpr_(const NotNode* op) override; + void VisitExpr_(const SelectNode* op) override; void VisitExpr_(const Ramp* op) override; void VisitExpr_(const Broadcast* op) override; void VisitExpr_(const Shuffle* op) override; @@ -356,18 +356,18 @@ class TVM_DLL ExprMutator : Expr VisitExpr_(const FloorModNode* op) override; Expr VisitExpr_(const MinNode* op) override; Expr VisitExpr_(const MaxNode* op) override; - Expr VisitExpr_(const EQ* op) override; - Expr VisitExpr_(const NE* op) override; - Expr VisitExpr_(const LT* op) override; - Expr VisitExpr_(const LE* op) override; - Expr VisitExpr_(const GT* op) override; - Expr VisitExpr_(const GE* op) override; - Expr VisitExpr_(const And* op) override; - Expr VisitExpr_(const Or* op) override; + Expr VisitExpr_(const EQNode* op) override; + Expr VisitExpr_(const NENode* op) override; + Expr VisitExpr_(const LTNode* op) override; + Expr VisitExpr_(const LENode* op) override; + Expr VisitExpr_(const GTNode* op) override; + Expr VisitExpr_(const GENode* op) override; + Expr VisitExpr_(const AndNode* op) override; + Expr VisitExpr_(const OrNode* op) override; Expr VisitExpr_(const Reduce* op) override; Expr VisitExpr_(const CastNode* op) override; - Expr VisitExpr_(const Not* op) override; - Expr VisitExpr_(const Select* op) override; + Expr VisitExpr_(const NotNode* op) override; + Expr VisitExpr_(const SelectNode* op) override; Expr VisitExpr_(const Ramp* op) override; Expr VisitExpr_(const Broadcast* op) override; Expr VisitExpr_(const Shuffle* op) override; diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 9d5bd255a6d1..6e411389fea0 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -143,17 +143,17 @@ REGISTER_MAKE(FloorDivNode); REGISTER_MAKE(FloorModNode); REGISTER_MAKE(MinNode); REGISTER_MAKE(MaxNode); -REGISTER_MAKE(EQ); -REGISTER_MAKE(NE); -REGISTER_MAKE(LT); -REGISTER_MAKE(LE); -REGISTER_MAKE(GT); -REGISTER_MAKE(GE); -REGISTER_MAKE(And); -REGISTER_MAKE(Or); - -REGISTER_MAKE(Not); -REGISTER_MAKE(Select); +REGISTER_MAKE(EQNode); +REGISTER_MAKE(NENode); +REGISTER_MAKE(LTNode); +REGISTER_MAKE(LENode); +REGISTER_MAKE(GTNode); +REGISTER_MAKE(GENode); +REGISTER_MAKE(AndNode); +REGISTER_MAKE(OrNode); + +REGISTER_MAKE(NotNode); +REGISTER_MAKE(SelectNode); REGISTER_MAKE(Ramp); REGISTER_MAKE(CastNode); REGISTER_MAKE(Broadcast); diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 65574ef6327c..40f86de7561a 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -94,19 +94,19 @@ class BoundDeducer: public ExprVisitor { } } - void VisitExpr_(const LT* op) final { + void VisitExpr_(const LTNode* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void VisitExpr_(const LE* op) final { + void VisitExpr_(const LENode* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void VisitExpr_(const GT* op) final { + void VisitExpr_(const GTNode* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } - void VisitExpr_(const GE* op) final { + void VisitExpr_(const GENode* op) final { LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } @@ -233,7 +233,7 @@ CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) { void BoundDeducer::Transform() { // We will ensure to set expr_ such that it contains target_ - if (const LT* op = expr_.as()) { + if (const LTNode* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a < b -> b >= a + 1 comp_op = kGreater; @@ -245,7 +245,7 @@ void BoundDeducer::Transform() { expr_ = op->a; result_ = op->b - 1; } - } else if (const LE* op = expr_.as()) { + } else if (const LENode* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a <= b -> b >= a comp_op = kGreater; @@ -256,7 +256,7 @@ void BoundDeducer::Transform() { expr_ = op->a; result_ = op->b; } - } else if (const GT* op = expr_.as()) { + } else if (const GTNode* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a > b -> b <= a - 1 comp_op = kLess; @@ -268,7 +268,7 @@ void BoundDeducer::Transform() { expr_ = op->a; result_ = op->b + 1; } - } else if (const GE* op = expr_.as()) { + } else if (const GENode* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a >= b -> b <= a comp_op = kLess; @@ -279,7 +279,7 @@ void BoundDeducer::Transform() { expr_ = op->a; result_ = op->b; } - } else if (const EQ* op = expr_.as()) { + } else if (const EQNode* op = expr_.as()) { comp_op = kEqual; if (GetPath(target_, op->a).empty()) { // if the b == a -> a == b diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 519b19c311d4..53a567a94a76 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -1106,7 +1106,7 @@ VisitExpr_(const Reduce* op) { // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]` // instead of `op->source[op->value_index]`. The former may be more difficult to simplify. return this->VisitExpr( - Select::make(op->condition, + SelectNode::make(op->condition, op->source[op->value_index], op->combiner->identity_element[op->value_index])); } diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index 1ca01d75b7ee..cbdfe42efa7c 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -266,7 +266,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value > pb->value); if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value > fb->value); @@ -275,7 +275,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value >= pb->value); if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value >= fb->value); @@ -284,7 +284,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value < pb->value); if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value < fb->value); @@ -293,7 +293,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value <= pb->value); if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value <= fb->value); @@ -302,7 +302,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value == pb->value); if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value == fb->value); @@ -311,7 +311,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(DataType::UInt(1), pa->value != pb->value); if (fa && fb) return UIntImm::make(DataType::UInt(1), fa->value != fb->value); @@ -320,7 +320,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { using ir::UIntImm; const UIntImm* pa = a.as(); const UIntImm* pb = b.as(); @@ -332,7 +332,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a, Expr b) { +inline Expr TryConstFold(Expr a, Expr b) { using ir::UIntImm; const UIntImm* pa = a.as(); const UIntImm* pb = b.as(); @@ -344,7 +344,7 @@ inline Expr TryConstFold(Expr a, Expr b) { } template<> -inline Expr TryConstFold(Expr a) { +inline Expr TryConstFold(Expr a) { using ir::UIntImm; const UIntImm* pa = a.as(); if (pa) { diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index 06e437e2ed82..c29895bba09b 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -264,7 +264,7 @@ class ConstIntBoundAnalyzer::Impl : return ret; } - Entry VisitExpr_(const Select* op) final { + Entry VisitExpr_(const SelectNode* op) final { Entry a = VisitExpr(op->true_value); Entry b = VisitExpr(op->false_value); return Union(a, b); diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 742e24c332da..7785801a5520 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -188,16 +188,16 @@ bool DetectClipBound( if (flag != 1) return false; // canonical form: exp >= 0 Expr canonical; - if (const LT* op = cond.as()) { + if (const LTNode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; canonical = op->b - op->a - make_const(op->a.dtype(), 1); - } else if (const LE* op = cond.as()) { + } else if (const LENode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; canonical = op->b - op->a; - } else if (const GT* op = cond.as()) { + } else if (const GTNode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; canonical = op->a - op->b - make_const(op->a.dtype(), 1); - } else if (const GE* op = cond.as()) { + } else if (const GENode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; canonical = op->a - op->b; } else { @@ -243,7 +243,7 @@ void SplitCommExpr(const Expr& e, std::vector* ret) { // e must be connected by and. Array DetectClipBound(const Expr& e, const Array& vars) { std::vector splits; - SplitCommExpr(e, &splits); + SplitCommExpr(e, &splits); std::unordered_map rmap; for (Var v : vars) { rmap[v.get()] = IntervalEntry(); diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 8cb52271b833..88ead8829681 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -83,15 +83,15 @@ struct is_logical_op { static const bool value = true; \ }; -TVM_DECLARE_LOGICAL_OP(And); -TVM_DECLARE_LOGICAL_OP(Or); -TVM_DECLARE_LOGICAL_OP(EQ); -TVM_DECLARE_LOGICAL_OP(NE); -TVM_DECLARE_LOGICAL_OP(GE); -TVM_DECLARE_LOGICAL_OP(GT); -TVM_DECLARE_LOGICAL_OP(LE); -TVM_DECLARE_LOGICAL_OP(LT); -TVM_DECLARE_LOGICAL_OP(Not); +TVM_DECLARE_LOGICAL_OP(AndNode); +TVM_DECLARE_LOGICAL_OP(OrNode); +TVM_DECLARE_LOGICAL_OP(EQNode); +TVM_DECLARE_LOGICAL_OP(NENode); +TVM_DECLARE_LOGICAL_OP(GENode); +TVM_DECLARE_LOGICAL_OP(GTNode); +TVM_DECLARE_LOGICAL_OP(LENode); +TVM_DECLARE_LOGICAL_OP(LTNode); +TVM_DECLARE_LOGICAL_OP(NotNode); /*! * \brief Combine two interval set under arithmetic operations. @@ -178,11 +178,11 @@ inline IntervalSet Combine(Analyzer* analyzer, Expr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using ir::Select; + using ir::SelectNode; Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); Expr e1 = a->min_value * b->min_value; Expr e2 = a->max_value * b->min_value; - return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); + return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); } } DLOG(WARNING) << "Return Everything in CombineInterval Mul"; @@ -213,11 +213,11 @@ inline IntervalSet Combine(Analyzer* analyzer, Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using ir::Select; + using ir::SelectNode; Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); Expr e1 = a->min_value / b->min_value; Expr e2 = a->max_value / b->min_value; - return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); + return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); } } DLOG(WARNING) << "Return Everything in CombineInterval Div"; @@ -279,11 +279,11 @@ inline IntervalSet Combine(Analyzer* analyzer, Expr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using ir::Select; + using ir::SelectNode; Expr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); Expr e1 = floordiv(a->min_value, b->min_value); Expr e2 = floordiv(a->max_value, b->min_value); - return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); + return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); } } DLOG(WARNING) << "Return Everything in CombineInterval Div"; @@ -441,35 +441,35 @@ class IntervalSetEvaluator : return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const EQ* op) final { + IntervalSet VisitExpr_(const EQNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const NE* op) final { + IntervalSet VisitExpr_(const NENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const LT* op) final { + IntervalSet VisitExpr_(const LTNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const LE* op) final { + IntervalSet VisitExpr_(const LENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const GT* op) final { + IntervalSet VisitExpr_(const GTNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const GE* op) final { + IntervalSet VisitExpr_(const GENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const And* op) final { + IntervalSet VisitExpr_(const AndNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const Or* op) final { + IntervalSet VisitExpr_(const OrNode* op) final { return VisitBinaryExpr_(op); } @@ -501,7 +501,7 @@ class IntervalSetEvaluator : return VisitExpr(op->value); } - IntervalSet VisitExpr_(const Select* op) final { + IntervalSet VisitExpr_(const SelectNode* op) final { IntervalSet true_set = this->Eval(op->true_value); IntervalSet false_set = this->Eval(op->false_value); return Union(analyzer_, false_set, true_set); diff --git a/src/arithmetic/ir_mutator_with_analyzer.cc b/src/arithmetic/ir_mutator_with_analyzer.cc index bfce2c26fbe3..97da9e667dfd 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.cc +++ b/src/arithmetic/ir_mutator_with_analyzer.cc @@ -66,7 +66,7 @@ VisitStmt_(const IfThenElse* op) { } if (op->else_case.defined()) { With ctx(analyzer_, - analyzer_->rewrite_simplify(Not::make(condition))); + analyzer_->rewrite_simplify(NotNode::make(condition))); else_case = this->VisitStmt(op->else_case); } if (is_one(condition)) return then_case; @@ -137,7 +137,7 @@ VisitExpr_(const Call* op) { } { With constraint(analyzer_, - analyzer_->rewrite_simplify(Not::make(cond))); + analyzer_->rewrite_simplify(NotNode::make(cond))); false_value = this->VisitExpr(op->args[2]); } if (is_zero(cond)) { @@ -177,7 +177,7 @@ VisitExpr_(const Let* op) { } Expr IRMutatorWithAnalyzer:: -VisitExpr_(const Select* op) { +VisitExpr_(const SelectNode* op) { Expr cond = this->VisitExpr(op->condition); Expr true_value, false_value; { @@ -186,7 +186,7 @@ VisitExpr_(const Select* op) { } { With constraint(analyzer_, - analyzer_->rewrite_simplify(Not::make(cond))); + analyzer_->rewrite_simplify(NotNode::make(cond))); false_value = VisitExpr(op->false_value); } if (is_zero(cond)) { @@ -201,7 +201,7 @@ VisitExpr_(const Select* op) { false_value.same_as(op->false_value)) { return GetRef(op); } else { - return Select::make(cond, true_value, false_value); + return SelectNode::make(cond, true_value, false_value); } } diff --git a/src/arithmetic/ir_mutator_with_analyzer.h b/src/arithmetic/ir_mutator_with_analyzer.h index 9e3a86bb5280..30e7619191b0 100644 --- a/src/arithmetic/ir_mutator_with_analyzer.h +++ b/src/arithmetic/ir_mutator_with_analyzer.h @@ -55,7 +55,7 @@ class IRMutatorWithAnalyzer : public ir::StmtExprMutator { Stmt VisitStmt_(const ir::AttrStmt* op) override; Stmt VisitStmt_(const ir::AssertStmt* op) override; Expr VisitExpr_(const ir::Let* op) override; - Expr VisitExpr_(const ir::Select* op) override; + Expr VisitExpr_(const ir::SelectNode* op) override; Expr VisitExpr_(const ir::Call* op) override; Expr VisitExpr_(const ir::Reduce* op) override; diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 09e9ee31ab44..ec8ce0c8d6de 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -216,7 +216,7 @@ class ModularSetAnalyzer::Impl : return Union(a, b); } - Entry VisitExpr_(const Select* op) final { + Entry VisitExpr_(const SelectNode* op) final { Entry a = VisitExpr(op->true_value); Entry b = VisitExpr(op->false_value); return Union(a, b); diff --git a/src/arithmetic/pattern_match.h b/src/arithmetic/pattern_match.h index 6670326e8046..161a28421d55 100644 --- a/src/arithmetic/pattern_match.h +++ b/src/arithmetic/pattern_match.h @@ -341,14 +341,14 @@ TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDivNode); TVM_PATTERN_BINARY_OP(floormod, ir::FloorModNode); // logical expressions -TVM_PATTERN_BINARY_OP(operator>, ir::GT); -TVM_PATTERN_BINARY_OP(operator>=, ir::GE); -TVM_PATTERN_BINARY_OP(operator<, ir::LT); -TVM_PATTERN_BINARY_OP(operator<=, ir::LE); -TVM_PATTERN_BINARY_OP(operator==, ir::EQ); -TVM_PATTERN_BINARY_OP(operator!=, ir::NE); -TVM_PATTERN_BINARY_OP(operator&&, ir::And); -TVM_PATTERN_BINARY_OP(operator||, ir::Or); +TVM_PATTERN_BINARY_OP(operator>, ir::GTNode); +TVM_PATTERN_BINARY_OP(operator>=, ir::GENode); +TVM_PATTERN_BINARY_OP(operator<, ir::LTNode); +TVM_PATTERN_BINARY_OP(operator<=, ir::LENode); +TVM_PATTERN_BINARY_OP(operator==, ir::EQNode); +TVM_PATTERN_BINARY_OP(operator!=, ir::NENode); +TVM_PATTERN_BINARY_OP(operator&&, ir::AndNode); +TVM_PATTERN_BINARY_OP(operator||, ir::OrNode); /*! * \brief Pattern not expression. @@ -365,7 +365,7 @@ class PNotExpr : public Pattern > { } bool Match_(const ObjectRef& node) const { - if (const ir::Not* ptr = node.as()) { + if (const ir::NotNode* ptr = node.as()) { if (!value_.Match_(ptr->a)) return false; return true; } else { @@ -374,7 +374,7 @@ class PNotExpr : public Pattern > { } Expr Eval() const { - return ir::Not::make(value_.Eval()); + return ir::NotNode::make(value_.Eval()); } private: @@ -411,7 +411,7 @@ class PSelectExpr : } bool Match_(const ObjectRef& node) const { - if (const ir::Select* ptr = node.as()) { + if (const ir::SelectNode* ptr = node.as()) { if (!condition_.Match_(ptr->condition)) return false; if (!true_value_.Match_(ptr->true_value)) return false; if (!false_value_.Match_(ptr->false_value)) return false; @@ -422,7 +422,7 @@ class PSelectExpr : } Expr Eval() const { - return ir::Select::make( + return ir::SelectNode::make( condition_.Eval(), true_value_.Eval(), false_value_.Eval()); } diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index b8fec1fbc7bb..0c20a5ca84cb 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -1353,10 +1353,10 @@ VisitExpr_(const MaxNode* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const EQ* op) { +VisitExpr_(const EQNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1387,30 +1387,30 @@ VisitExpr_(const EQ* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const NE* op) { - return this->VisitExpr(Not::make(op->a == op->b)); +VisitExpr_(const NENode* op) { + return this->VisitExpr(NotNode::make(op->a == op->b)); } Expr RewriteSimplifier::Impl:: -VisitExpr_(const LE* op) { - return this->VisitExpr(Not::make(op->b < op->a)); +VisitExpr_(const LENode* op) { + return this->VisitExpr(NotNode::make(op->b < op->a)); } Expr RewriteSimplifier::Impl:: -VisitExpr_(const GT* op) { +VisitExpr_(const GTNode* op) { return this->VisitExpr(op->b < op->a); } Expr RewriteSimplifier::Impl:: -VisitExpr_(const GE* op) { - return this->VisitExpr(Not::make(op->a < op->b)); +VisitExpr_(const GENode* op) { + return this->VisitExpr(NotNode::make(op->a < op->b)); } Expr RewriteSimplifier::Impl:: -VisitExpr_(const LT* op) { +VisitExpr_(const LTNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1563,10 +1563,10 @@ VisitExpr_(const LT* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Not* op) { +VisitExpr_(const NotNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a); + op = ret.as(); + Expr const_res = TryConstFold(op->a); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y; @@ -1588,10 +1588,10 @@ VisitExpr_(const Not* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const And* op) { +VisitExpr_(const AndNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1637,10 +1637,10 @@ VisitExpr_(const And* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Or* op) { +VisitExpr_(const OrNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as(); - Expr const_res = TryConstFold(op->a, op->b); + op = ret.as(); + Expr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1687,9 +1687,9 @@ VisitExpr_(const Or* op) { } Expr RewriteSimplifier::Impl:: -VisitExpr_(const Select* op) { +VisitExpr_(const SelectNode* op) { Expr ret = IRMutatorWithAnalyzer::VisitExpr_(op); - op = ret.as()) { +bool AttrsEqualHandler::VisitAttr_(const SelectNode* lhs, const ObjectRef& other) { + if (const auto* rhs = other.as()) { return Equal(lhs->condition, rhs->condition) && Equal(lhs->true_value, rhs->true_value) && @@ -274,17 +274,17 @@ TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode); TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode); TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode); TVM_DEFINE_ATTRS_BINOP_HASH(MinNode); -TVM_DEFINE_ATTRS_BINOP_HASH(GE); -TVM_DEFINE_ATTRS_BINOP_HASH(GT); -TVM_DEFINE_ATTRS_BINOP_HASH(LE); -TVM_DEFINE_ATTRS_BINOP_HASH(LT); -TVM_DEFINE_ATTRS_BINOP_HASH(EQ); -TVM_DEFINE_ATTRS_BINOP_HASH(NE); -TVM_DEFINE_ATTRS_BINOP_HASH(And); -TVM_DEFINE_ATTRS_BINOP_HASH(Or); - -size_t AttrsHashHandler::VisitAttr_(const Not* op) { - static size_t key = std::hash()(Not::_type_key); +TVM_DEFINE_ATTRS_BINOP_HASH(GENode); +TVM_DEFINE_ATTRS_BINOP_HASH(GTNode); +TVM_DEFINE_ATTRS_BINOP_HASH(LENode); +TVM_DEFINE_ATTRS_BINOP_HASH(LTNode); +TVM_DEFINE_ATTRS_BINOP_HASH(EQNode); +TVM_DEFINE_ATTRS_BINOP_HASH(NENode); +TVM_DEFINE_ATTRS_BINOP_HASH(AndNode); +TVM_DEFINE_ATTRS_BINOP_HASH(OrNode); + +size_t AttrsHashHandler::VisitAttr_(const NotNode* op) { + static size_t key = std::hash()(NotNode::_type_key); return Combine(key, Hash(op->a)); } @@ -307,8 +307,8 @@ size_t AttrsHashHandler::VisitAttr_(const Call* op) { return res; } -size_t AttrsHashHandler::VisitAttr_(const Select* op) { - static size_t key = std::hash()(Select::_type_key); +size_t AttrsHashHandler::VisitAttr_(const SelectNode* op) { + static size_t key = std::hash()(SelectNode::_type_key); size_t res = key; res = Combine(res, Hash(op->condition)); res = Combine(res, Hash(op->true_value)); diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index cdab989aae28..b3f302808565 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -353,67 +353,67 @@ Expr likely(Expr cond) { Expr operator>(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::GT::make(a, b); + return ir::GTNode::make(a, b); } Expr operator>=(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::GE::make(a, b); + return ir::GENode::make(a, b); } Expr operator<(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::LT::make(a, b); + return ir::LTNode::make(a, b); } Expr operator<=(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::LE::make(a, b); + return ir::LENode::make(a, b); } Expr operator==(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::EQ::make(a, b); + return ir::EQNode::make(a, b); } Expr operator!=(Expr a, Expr b) { BinaryOpMatchTypes(a, b); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::NE::make(a, b); + return ir::NENode::make(a, b); } Expr operator&&(Expr a, Expr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::And::make(a, b); + return ir::AndNode::make(a, b); } Expr operator||(Expr a, Expr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); - Expr ret = arith::TryConstFold(a, b); + Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return ir::Or::make(a, b); + return ir::OrNode::make(a, b); } Expr operator!(Expr a) { CHECK(a.dtype().is_bool()); - Expr ret = arith::TryConstFold(a); + Expr ret = arith::TryConstFold(a); if (ret.defined()) return ret; - return ir::Not::make(a); + return ir::NotNode::make(a); } Expr operator>>(Expr a, Expr b) { @@ -485,7 +485,7 @@ Expr abs(Expr x) { if (px) { return ir::IntImm::make(x.dtype(), std::abs(px->value)); } - return ir::Select::make(x >= make_zero(x.dtype()), x, -x); + return ir::SelectNode::make(x >= make_zero(x.dtype()), x, -x); } else if (x.dtype().is_float()) { using ir::FloatImm; const FloatImm* fx = x.as(); @@ -538,7 +538,7 @@ Expr sum(Expr source, Array rdom) { Expr all(Expr source, Array rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::And::make(x, y); + Expr result = ir::AndNode::make(x, y); Expr identity_element = make_const(source.dtype(), true); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); @@ -548,7 +548,7 @@ Expr all(Expr source, Array rdom) { Expr any(Expr source, Array rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); - Expr result = ir::Or::make(x, y); + Expr result = ir::OrNode::make(x, y); Expr identity_element = make_const(source.dtype(), false); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 3b298769e1f7..2c873926d02b 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -65,45 +65,45 @@ Expr CastNode::make(DataType t, Expr value) { return Expr(node); } -Expr And::make(Expr a, Expr b) { +Expr AndNode::make(Expr a, Expr b) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(b.defined()) << "ValueError: b is undefined"; CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); return Expr(node); } -Expr Or::make(Expr a, Expr b) { +Expr OrNode::make(Expr a, Expr b) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(b.defined()) << "ValueError: b is undefined"; CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types"; - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); return Expr(node); } -Expr Not::make(Expr a) { +Expr NotNode::make(Expr a) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(a.dtype().is_bool()); - ObjectPtr node = make_object(); + ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); return Expr(node); } -Expr Select::make(Expr condition, Expr true_value, Expr false_value) { +Expr SelectNode::make(Expr condition, Expr true_value, Expr false_value) { CHECK(condition.defined()) << "ValueError: condition is undefined"; CHECK(true_value.defined()) << "ValueError: true_value is undefined"; CHECK(false_value.defined()) << "ValueError: true_value is undefined"; @@ -111,7 +111,7 @@ Expr Select::make(Expr condition, Expr true_value, Expr false_value) { CHECK_EQ(condition.dtype().lanes(), true_value.dtype().lanes()); CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types"; - ObjectPtr(); + ObjectPtr node = make_object(); node->dtype = true_value.dtype(); node->condition = std::move(condition); node->true_value = std::move(true_value); @@ -661,48 +661,48 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) p->Print(op->b); p->stream << ")"; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " == "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " != "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " < "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " <= "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " > "; p->Print(op->b); p->stream << ')'; }) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " >= "; @@ -723,8 +723,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " && "; @@ -733,8 +733,8 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '('; p->Print(op->a); p->stream << " || "; @@ -743,15 +743,15 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); p->stream << '!'; p->Print(op->a); }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch(); + void VisitExpr_(const SelectNode *op, const Expr& other) final { + const SelectNode* rhs = other.as(); if (CompareExpr(op->condition, rhs->condition) != 0) return; if (CompareExpr(op->true_value, rhs->true_value) != 0) return; if (CompareExpr(op->false_value, rhs->false_value) != 0) return; @@ -307,14 +307,14 @@ class IRDeepCompare : DEFINE_BIOP_EXPR_CMP_(FloorModNode) DEFINE_BIOP_EXPR_CMP_(MinNode) DEFINE_BIOP_EXPR_CMP_(MaxNode) - DEFINE_BIOP_EXPR_CMP_(EQ) - DEFINE_BIOP_EXPR_CMP_(NE) - DEFINE_BIOP_EXPR_CMP_(LT) - DEFINE_BIOP_EXPR_CMP_(LE) - DEFINE_BIOP_EXPR_CMP_(GT) - DEFINE_BIOP_EXPR_CMP_(GE) - DEFINE_BIOP_EXPR_CMP_(And) - DEFINE_BIOP_EXPR_CMP_(Or) + DEFINE_BIOP_EXPR_CMP_(EQNode) + DEFINE_BIOP_EXPR_CMP_(NENode) + DEFINE_BIOP_EXPR_CMP_(LTNode) + DEFINE_BIOP_EXPR_CMP_(LENode) + DEFINE_BIOP_EXPR_CMP_(GTNode) + DEFINE_BIOP_EXPR_CMP_(GENode) + DEFINE_BIOP_EXPR_CMP_(AndNode) + DEFINE_BIOP_EXPR_CMP_(OrNode) private: int CompareExpr(const Expr& lhs, const Expr& rhs) { diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index 7e40aff8dfd8..9e2c27a998dd 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -250,14 +250,14 @@ DEFINE_BINOP_VISIT_(FloorDivNode); DEFINE_BINOP_VISIT_(FloorModNode); DEFINE_BINOP_VISIT_(MinNode); DEFINE_BINOP_VISIT_(MaxNode); -DEFINE_BINOP_VISIT_(EQ); -DEFINE_BINOP_VISIT_(NE); -DEFINE_BINOP_VISIT_(LT); -DEFINE_BINOP_VISIT_(LE); -DEFINE_BINOP_VISIT_(GT); -DEFINE_BINOP_VISIT_(GE); -DEFINE_BINOP_VISIT_(And); -DEFINE_BINOP_VISIT_(Or); +DEFINE_BINOP_VISIT_(EQNode); +DEFINE_BINOP_VISIT_(NENode); +DEFINE_BINOP_VISIT_(LTNode); +DEFINE_BINOP_VISIT_(LENode); +DEFINE_BINOP_VISIT_(GTNode); +DEFINE_BINOP_VISIT_(GENode); +DEFINE_BINOP_VISIT_(AndNode); +DEFINE_BINOP_VISIT_(OrNode); void ExprVisitor::VisitExpr_(const IntImm* op) {} void ExprVisitor::VisitExpr_(const UIntImm* op) {} @@ -277,11 +277,11 @@ void ExprVisitor::VisitExpr_(const CastNode* op) { this->VisitExpr(op->value); } -void ExprVisitor::VisitExpr_(const Not* op) { +void ExprVisitor::VisitExpr_(const NotNode* op) { this->VisitExpr(op->a); } -void ExprVisitor::VisitExpr_(const Select* op) { +void ExprVisitor::VisitExpr_(const SelectNode* op) { this->VisitExpr(op->condition); this->VisitExpr(op->true_value); this->VisitExpr(op->false_value); @@ -665,14 +665,14 @@ DEFINE_BIOP_EXPR_MUTATE_(FloorDivNode); DEFINE_BIOP_EXPR_MUTATE_(FloorModNode); DEFINE_BIOP_EXPR_MUTATE_(MinNode); DEFINE_BIOP_EXPR_MUTATE_(MaxNode); -DEFINE_BIOP_EXPR_MUTATE_(EQ); -DEFINE_BIOP_EXPR_MUTATE_(NE); -DEFINE_BIOP_EXPR_MUTATE_(LT); -DEFINE_BIOP_EXPR_MUTATE_(LE); -DEFINE_BIOP_EXPR_MUTATE_(GT); -DEFINE_BIOP_EXPR_MUTATE_(GE); -DEFINE_BIOP_EXPR_MUTATE_(And); -DEFINE_BIOP_EXPR_MUTATE_(Or); +DEFINE_BIOP_EXPR_MUTATE_(EQNode); +DEFINE_BIOP_EXPR_MUTATE_(NENode); +DEFINE_BIOP_EXPR_MUTATE_(LTNode); +DEFINE_BIOP_EXPR_MUTATE_(LENode); +DEFINE_BIOP_EXPR_MUTATE_(GTNode); +DEFINE_BIOP_EXPR_MUTATE_(GENode); +DEFINE_BIOP_EXPR_MUTATE_(AndNode); +DEFINE_BIOP_EXPR_MUTATE_(OrNode); Expr ExprMutator::VisitExpr_(const Reduce* op) { auto fitervar = [this](const IterVar& v) { @@ -714,16 +714,16 @@ Expr ExprMutator::VisitExpr_(const CastNode* op) { } } -Expr ExprMutator::VisitExpr_(const Not* op) { +Expr ExprMutator::VisitExpr_(const NotNode* op) { Expr a = this->VisitExpr(op->a); if (a.same_as(op->a)) { return GetRef(op); } else { - return Not::make(a); + return NotNode::make(a); } } -Expr ExprMutator::VisitExpr_(const Select* op) { +Expr ExprMutator::VisitExpr_(const SelectNode* op) { Expr condition = this->VisitExpr(op->condition); Expr true_value = this->VisitExpr(op->true_value); Expr false_value = this->VisitExpr(op->false_value); @@ -732,7 +732,7 @@ Expr ExprMutator::VisitExpr_(const Select* op) { false_value.same_as(op->false_value)) { return GetRef(op); } else { - return Select::make(condition, true_value, false_value); + return SelectNode::make(condition, true_value, false_value); } } diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 6ea2959757fc..ec11deca0c6c 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -226,24 +226,24 @@ class PartitionFinder : public StmtExprVisitor { private: Expr InverseCond(const Expr& cond) { Expr inverse_cond; - if (const LT* op = cond.as()) { + if (const LTNode* op = cond.as()) { // a < b -> a >= b - inverse_cond = GE::make(op->a, op->b); - } else if (const GT* op = cond.as()) { + inverse_cond = GENode::make(op->a, op->b); + } else if (const GTNode* op = cond.as()) { // a > b -> a <= b - inverse_cond = LE::make(op->a, op->b); - } else if (const LE* op = cond.as()) { + inverse_cond = LENode::make(op->a, op->b); + } else if (const LENode* op = cond.as()) { // a <= b -> a > b - inverse_cond = GT::make(op->a, op->b); - } else if (const GE* op = cond.as()) { + inverse_cond = GTNode::make(op->a, op->b); + } else if (const GENode* op = cond.as()) { // a >= b -> a < b - inverse_cond = LT::make(op->a, op->b); - } else if (const EQ* op = cond.as()) { + inverse_cond = LTNode::make(op->a, op->b); + } else if (const EQNode* op = cond.as()) { // a == b -> a != b - inverse_cond = NE::make(op->a, op->b); + inverse_cond = NENode::make(op->a, op->b); // a != b -> a == b - } else if (const NE* op = cond.as()) { - inverse_cond = EQ::make(op->a, op->b); + } else if (const NENode* op = cond.as()) { + inverse_cond = EQNode::make(op->a, op->b); } return inverse_cond; } diff --git a/src/pass/lower_custom_datatypes.cc b/src/pass/lower_custom_datatypes.cc index e9c4632fe08b..603b2b2bd20c 100644 --- a/src/pass/lower_custom_datatypes.cc +++ b/src/pass/lower_custom_datatypes.cc @@ -118,12 +118,12 @@ class CustomDatatypesLowerer : public StmtExprMutator { DEFINE_MUTATE__(Mod, ModNode); DEFINE_MUTATE__(Min, MinNode); DEFINE_MUTATE__(Max, MaxNode); - DEFINE_MUTATE__(EQ, EQ); - DEFINE_MUTATE__(NE, NE); - DEFINE_MUTATE__(LT, LT); - DEFINE_MUTATE__(LE, LE); - DEFINE_MUTATE__(GT, GT); - DEFINE_MUTATE__(GE, GE); + DEFINE_MUTATE__(EQ, EQNode); + DEFINE_MUTATE__(NE, NENode); + DEFINE_MUTATE__(LT, LTNode); + DEFINE_MUTATE__(LE, LENode); + DEFINE_MUTATE__(GT, GTNode); + DEFINE_MUTATE__(GE, GENode); // Later changes may need to add more mutate functions as we support workloads with more ops. private: diff --git a/src/pass/lower_intrin.cc b/src/pass/lower_intrin.cc index 9ef5ad92cac3..a4fa71f3d1b7 100644 --- a/src/pass/lower_intrin.cc +++ b/src/pass/lower_intrin.cc @@ -104,7 +104,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // equivalent to rdiv + (rmod >= 0 ? 0: -1); return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); } else { - return ir::Select::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1)); + return ir::SelectNode::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1)); } } } else { @@ -114,7 +114,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) Expr rdiv = truncdiv(op->a, op->b); Expr rmod = truncmod(op->a, op->b); - return ir::Select::make( + return ir::SelectNode::make( (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, rdiv - make_const(dtype, 1)); } @@ -153,7 +153,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // -> rmod >= 0 ? 0 : b return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1))); } else { - return ir::Select::make(rmod >= 0, rmod, rmod + op->b); + return ir::SelectNode::make(rmod >= 0, rmod, rmod + op->b); } } } else { @@ -164,7 +164,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b > 0 && rmod < 0 -> rmod + b // b < 0 && rmod < 0 -> rmod // b < 0 && rmod > 0 -> rmod + b - return ir::Select::make( + return ir::SelectNode::make( (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b); } @@ -183,7 +183,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return IRMutatorWithAnalyzer::VisitExpr_(op); } - Expr VisitExpr_(const EQ* op) final { + Expr VisitExpr_(const EQNode* op) final { using namespace arith; PVar x, y; auto e = GetRef(op); @@ -193,7 +193,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return IRMutatorWithAnalyzer::VisitExpr_(op); } - Expr VisitExpr_(const NE* op) final { + Expr VisitExpr_(const NENode* op) final { using namespace arith; PVar x, y; auto e = GetRef(op); diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index 4712bccb415a..7b1870dfef3a 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -130,7 +130,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t idx = 0; idx < size; ++idx) { values[idx] = call->args[1+idx]; if (!is_one(cond)) { - values[idx] = Select::make(cond, values[idx], inits[idx]); + values[idx] = SelectNode::make(cond, values[idx], inits[idx]); } types[idx] = values[idx].dtype(); } diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index 048289e24710..bfea089dd58c 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -240,10 +240,10 @@ class DeviceTypeBinder: public StmtExprMutator { return res; } - Expr VisitExpr_(const NE* op) final { + Expr VisitExpr_(const NENode* op) final { // eager check NE for device check Expr res = StmtExprMutator::VisitExpr_(op); - op = res.as(); + op = res.as(); if (ir::Equal(op->a, op->b)) { return make_const(op->dtype, false); } diff --git a/src/pass/rewrite_unsafe_select.cc b/src/pass/rewrite_unsafe_select.cc index b1d840afbfe9..2c6cefc077ba 100644 --- a/src/pass/rewrite_unsafe_select.cc +++ b/src/pass/rewrite_unsafe_select.cc @@ -35,7 +35,7 @@ class UnsafeExprDetector : public ExprFunctor { public: // select itself is always considered safe if condition is safe // Because we will issue guard to make sure it is. - bool VisitExpr_(const Select* op) { + bool VisitExpr_(const SelectNode* op) { return VisitExpr(op->condition); } bool VisitExpr_(const Call* op) { @@ -66,15 +66,15 @@ class UnsafeExprDetector : public ExprFunctor { bool VisitExpr_(const FloorModNode* op) final { return BinaryOp(op); } bool VisitExpr_(const MinNode* op) final { return BinaryOp(op); } bool VisitExpr_(const MaxNode* op) final { return BinaryOp(op); } - bool VisitExpr_(const EQ* op) final { return BinaryOp(op); } - bool VisitExpr_(const NE* op) final { return BinaryOp(op); } - bool VisitExpr_(const LT* op) final { return BinaryOp(op); } - bool VisitExpr_(const LE* op) final { return BinaryOp(op); } - bool VisitExpr_(const GT* op) final { return BinaryOp(op); } - bool VisitExpr_(const GE* op) final { return BinaryOp(op); } - bool VisitExpr_(const And* op) final { return BinaryOp(op); } - bool VisitExpr_(const Or* op) final { return BinaryOp(op); } - bool VisitExpr_(const Not* op) final { + bool VisitExpr_(const EQNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const NENode* op) final { return BinaryOp(op); } + bool VisitExpr_(const LTNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const LENode* op) final { return BinaryOp(op); } + bool VisitExpr_(const GTNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const GENode* op) final { return BinaryOp(op); } + bool VisitExpr_(const AndNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const OrNode* op) final { return BinaryOp(op); } + bool VisitExpr_(const NotNode* op) final { return VisitExpr(op->a); } bool VisitExpr_(const Let* op) final { @@ -110,9 +110,9 @@ class UnsafeExprDetector : public ExprFunctor { class UnsafeSelectRewriter : public StmtExprMutator { public: - Expr VisitExpr_(const Select* op) { + Expr VisitExpr_(const SelectNode* op) { Expr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as