Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 31 additions & 31 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class ConstIntBoundAnalyzer {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
ConstIntBound operator()(const Expr& expr);
ConstIntBound operator()(const PrimExpr& expr);

/*!
* \brief Update constant int bound information of var.
Expand Down Expand Up @@ -136,7 +136,7 @@ class ConstIntBoundAnalyzer {
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const Expr& constraint);
std::function<void()> EnterConstraint(const PrimExpr& constraint);
struct Entry;
class Impl;
/*! \brief Internal impl */
Expand Down Expand Up @@ -192,7 +192,7 @@ class ModularSetAnalyzer {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
ModularSet operator()(const Expr& expr);
ModularSet operator()(const PrimExpr& expr);
/*!
* \brief Update constant int bound information of var.
*
Expand All @@ -215,7 +215,7 @@ class ModularSetAnalyzer {
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const Expr& constraint);
std::function<void()> EnterConstraint(const PrimExpr& constraint);
struct Entry;
class Impl;
/*! \brief Internal impl */
Expand All @@ -232,7 +232,7 @@ class RewriteSimplifier {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr operator()(const Expr& expr);
PrimExpr operator()(const PrimExpr& expr);

/*!
* \brief Update binding of var to a new expression.
Expand All @@ -242,10 +242,10 @@ class RewriteSimplifier {
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const Expr& new_expr,
const PrimExpr& new_expr,
bool override = false);

std::function<void()> EnterConstraint(const Expr& constraint);
std::function<void()> EnterConstraint(const PrimExpr& constraint);

private:
friend class Analyzer;
Expand All @@ -268,7 +268,7 @@ class CanonicalSimplifier {
* \param expr The expression of interest.
* \return the result of the analysis.
*/
Expr operator()(const Expr& expr);
PrimExpr operator()(const PrimExpr& expr);

/*!
* \brief Update binding of var to a new expression.
Expand All @@ -278,7 +278,7 @@ class CanonicalSimplifier {
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const Expr& new_expr,
const PrimExpr& new_expr,
bool override = false);

private:
Expand Down Expand Up @@ -316,7 +316,7 @@ class ConstraintContext {
* \param analyzer The analyzer.
* \param constraint The constraint to be applied.
*/
ConstraintContext(Analyzer* analyzer, Expr constraint)
ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
: analyzer_(analyzer), constraint_(constraint) {}
// enter the scope.
void EnterWithScope();
Expand All @@ -325,7 +325,7 @@ class ConstraintContext {
/*! \brief The analyzer */
Analyzer* analyzer_;
/*! \brief The constraint */
Expr constraint_;
PrimExpr constraint_;
/*! \brief function to be called in recovery */
std::function<void()> exit_;
};
Expand Down Expand Up @@ -375,9 +375,9 @@ class IntSet : public ObjectRef {
*/
Range cover_range(Range max_range) const;
/*! \return Lower bound of the set */
Expr min() const;
PrimExpr min() const;
/*! \return upper bound of the set */
Expr max() const;
PrimExpr max() const;
/*! \return Whether the set represent nothing */
bool is_nothing() const;
/*! \return Whether the set represent everything */
Expand All @@ -398,7 +398,7 @@ class IntSet : public ObjectRef {
* \brief The single point value, call only if is_single_point is true
* \return The point value.
*/
Expr point_value() const;
PrimExpr point_value() const;
/*!
* \brief Try to match IntSet with range r.
*
Expand All @@ -415,13 +415,13 @@ class IntSet : public ObjectRef {
* \param point The point in the set.
* \return construct a single point set
*/
static IntSet single_point(Expr point);
static IntSet single_point(PrimExpr point);
/*!
* \brief construct a integer set from vector expression.
* \param vec The vector expression, can also be single point.
* \return The result set containing the indices in the vector.
*/
static IntSet vector(Expr vec);
static IntSet vector(PrimExpr vec);
/*!
* \brief Construct a set representing a range.
* \param r The range
Expand All @@ -434,7 +434,7 @@ class IntSet : public ObjectRef {
* \param max The maximum value of the interval.
* \return constructed set.
*/
static IntSet interval(Expr min, Expr max);
static IntSet interval(PrimExpr min, PrimExpr max);
};

/*!
Expand All @@ -450,7 +450,7 @@ class IntSetAnalyzer {
* \param dom_map The domain map to indicate which variable to relax.
* \return the result of the analysis.
*/
IntSet operator()(const Expr& expr, const Map<Var, IntSet>& dom_map);
IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);

private:
friend class Analyzer;
Expand Down Expand Up @@ -499,7 +499,7 @@ class Analyzer {
* \param var The variable.
* \param expr The expression we bind to.
*/
void Bind(const VarExpr& var, const Expr& expr);
void Bind(const Var& var, const PrimExpr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
Expand All @@ -509,7 +509,7 @@ class Analyzer {
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const VarExpr& var, const Range& range);
void Bind(const Var& var, const Range& range);
/*!
* \brief Whether can we prove expr >= val.

Expand All @@ -522,7 +522,7 @@ class Analyzer {
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
/*!
* \brief Whether can we prove condition.
*
Expand All @@ -531,7 +531,7 @@ class Analyzer {
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProve(const Expr& cond);
bool CanProve(const PrimExpr& cond);
/*!
* \brief Simplify expr.
*
Expand All @@ -540,7 +540,7 @@ class Analyzer {
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
Expr Simplify(const Expr& expr);
PrimExpr Simplify(const PrimExpr& expr);
};

//-----------------------------------------------
Expand All @@ -554,7 +554,7 @@ class Analyzer {
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Expr e,
IntSet EvalSet(PrimExpr e,
const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
Expand All @@ -563,7 +563,7 @@ IntSet EvalSet(Expr e,
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(Expr e,
IntSet EvalSet(PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*!
Expand Down Expand Up @@ -598,7 +598,7 @@ IntSet EvalSet(Range r,
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
Expand All @@ -608,7 +608,7 @@ using ExprIntSetMap = std::unordered_map<Expr, IntSet, ObjectHash, ObjectEqual>;
* \return the map from the expression to its possible value.
*/
ExprIntSetMap EvalSetForEachSubExpr(
Expr e,
PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*!
Expand Down Expand Up @@ -640,7 +640,7 @@ IntSet Intersect(const Array<IntSet>& sets);
* The deduce bound must implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
IntSet DeduceBound(PrimExpr v, PrimExpr cond,
const Map<Var, IntSet>& hint_map,
const Map<Var, IntSet>& relax_map);
/*!
Expand All @@ -653,7 +653,7 @@ IntSet DeduceBound(Expr v, Expr cond,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
IntSet DeduceBound(PrimExpr v, PrimExpr cond,
const std::unordered_map<const VarNode*, IntSet>& hint_map,
const std::unordered_map<const VarNode*, IntSet>& relax_map);

Expand All @@ -676,7 +676,7 @@ Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array<Expr> DetectLinearEquation(const Expr& e,
Array<PrimExpr> DetectLinearEquation(const PrimExpr& e,
const Array<Var>& vars);

/*!
Expand All @@ -687,7 +687,7 @@ Array<Expr> DetectLinearEquation(const Expr& e,
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/
Array<Expr> DetectClipBound(const Expr& e,
Array<PrimExpr> DetectClipBound(const PrimExpr& e,
const Array<Var>& vars);

// implementation
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLInt) {
*ptr = static_cast<T>(val.value().v_int64);
} else {
Expr expr = val;
PrimExpr expr = val;
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value);
Expand All @@ -502,7 +502,7 @@ inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
if (val.type_code() == kStr) {
*ptr = val.operator std::string();
} else {
Expr expr = val;
PrimExpr expr = val;
const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
CHECK(op != nullptr);
*ptr = op->value;
Expand All @@ -517,7 +517,7 @@ inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
*ptr = val.operator double();
} else {
Expr expr = val;
PrimExpr expr = val;
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<double>(op->value);
Expand Down
24 changes: 12 additions & 12 deletions include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,30 +66,30 @@ class Buffer : public ObjectRef {
* If stride is not needed in the slice, it won't be presented
* \return the result buffer.
*/
TVM_DLL Buffer MakeSlice(Array<Expr> begins, Array<Expr> extents) const;
TVM_DLL Buffer MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const;
/*!
* \brief Get access ptr to the entire buffer.
* \param access_mask The access mask
* \param ptr_type The type of the pointer.
* \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
*/
TVM_DLL Expr access_ptr(int access_mask,
TVM_DLL PrimExpr access_ptr(int access_mask,
DataType ptr_type = DataType::Handle(),
int content_lanes = 1,
Expr offset = make_const(DataType::Int(32), 0)) const;
PrimExpr offset = make_const(DataType::Int(32), 0)) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
* \param dtype The data type to be loaded.
*/
TVM_DLL Expr vload(Array<Expr> begin, DataType dtype) const;
TVM_DLL PrimExpr vload(Array<PrimExpr> begin, DataType dtype) const;
/*!
* \brief Create a Stmt that does a vector store at begin index.
* \param begin The beginning index
* \param value The value to be stored.
*/
TVM_DLL Stmt vstore(Array<Expr> begin, Expr value) const;
TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand All @@ -112,14 +112,14 @@ class BufferNode : public Object {
/*! \brief data type in the content of the tensor */
DataType dtype;
/*! \brief The shape of the buffer */
Array<Expr> shape;
Array<PrimExpr> shape;
/*!
* \brief The strides of each dimension
* This can be an empty array, indicating array is contiguous
*/
Array<Expr> strides;
Array<PrimExpr> strides;
/*! \brief The offset in terms of number of dtype elements (including lanes) */
Expr elem_offset;
PrimExpr elem_offset;
// Meta data
/*! \brief optional name of the buffer */
std::string name;
Expand Down Expand Up @@ -159,9 +159,9 @@ class BufferNode : public Object {
// A default value will be picked.
TVM_DLL static Buffer make(Var ptr,
DataType dtype,
Array<Expr> shape,
Array<Expr> strides,
Expr elem_offset,
Array<PrimExpr> shape,
Array<PrimExpr> strides,
PrimExpr elem_offset,
std::string name,
std::string scope,
int data_alignment,
Expand All @@ -184,7 +184,7 @@ inline const BufferNode* Buffer::operator->() const {
* \return The created buffer.
* \sa BufferNode::make for complete constructor.
*/
TVM_DLL Buffer decl_buffer(Array<Expr> shape,
TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape,
DataType dtype = DataType::Float(32),
std::string name = "buffer");
} // namespace tvm
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ class TargetNode : public Object {
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief Keys for this target */
Array<Expr> keys_array;
Array<PrimExpr> keys_array;
/*! \brief Options for this target */
Array<Expr> options_array;
Array<PrimExpr> options_array;
/*! \brief Collection of imported libs */
Array<Expr> libs_array;
Array<PrimExpr> libs_array;

/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
Expand Down
Loading