Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ class DataType {
bool is_vector_bool() const { return is_vector() && bits() == 1; }
/*! \return whether type is a Void type. */
bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; }
/* ! \return whether type is a signed or an unsigned int. */
bool is_integer_type() const { return is_int() || is_uint(); }
/*!
* \brief Create a new data type by change lanes to a specified value.
* \param lanes The target number of lanes.
Expand Down
95 changes: 94 additions & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,6 @@ inline bool is_const_int(const PrimExpr& x);

/*!
* \brief Check whether x is an integer/float constant.
* \note This only return true for integer types.
* \return whether x is constant
*/
inline bool is_const_number(const PrimExpr& x);
Expand Down Expand Up @@ -882,6 +881,30 @@ inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr
*/
TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);

/*!
* \brief Try to find the narrowest type which can still represent the value of x.
* \param x The input expression.
* \param round_to_bytes Whether the result should be a power-of-two number of bytes.
* \return DataType that can hold x, but possibly with fewer bits. For immediate values
* of integer types, if the value is non-negative, the returned type will be
* an unsigned with the minimal number of required bits (subject to rounding).
* For floating point values the type code will remain unchanged.
* \note The rounding to bytes does not apply to bool type.
*/
inline DataType restricted_type(const PrimExpr& x, bool round_to_bytes = false);

/*!
* \brief Try to set the dtype of the expression to the specified type without
* modifying the expression itself (i.e. without adding casts, etc.), if the
* new type preserves all possible values of the expression.
*
* \param t The desired output type.
* \param x The input expression.
* \return The expression x with a new dtype t, if x with type t assumes the same
* values as x, or NullOpt otherwise.
*/
inline Optional<PrimExpr> try_reset_expr_dtype(DataType t, const PrimExpr& x);

// Implementation details after this
inline bool is_const_int(const PrimExpr& x) { return as_const_int(x); }

Expand Down Expand Up @@ -973,6 +996,76 @@ inline PrimExpr make_zero(DataType t, Span span) {
return make_const(t, 0, span);
}

inline Optional<PrimExpr> try_reset_expr_dtype(DataType t, const PrimExpr& x) {
Optional<PrimExpr> none = NullOpt;
if (!is_const_number(x)) {
return x.dtype() == t ? x : none;
}

DataType narrow_t = restricted_type(x, false);
if (narrow_t.code() == t.code()) {
return narrow_t.bits() <= t.bits() ? tvm::cast(t, x) : none;
}
if (narrow_t.is_int() && t.is_uint()) {
// Non-negative integer immediates will always have the restricted type of uint,
// so if the type is int, the immediate must be negative, hence not representable
// in an unsigned type.
return none;
} else if (narrow_t.is_uint() && t.is_int()) {
// The non-negative immediate can be switched to a signed type, if the signed type
// has at least one more bit.
return narrow_t.bits() < t.bits() ? tvm::cast(t, x) : none;
}
return none;
}

inline DataType restricted_type(const PrimExpr& x, bool round_to_bytes) {
if (const int64_t* val = as_const_int(x)) {
int64_t v = *val;
if (x.dtype().is_integer_type() && (v == 0 || v == 1)) {
return DataType::Bool();
}
auto num_significant_bits = [](uint64_t t) -> unsigned {
// Always return at least 1.
if (t == 0) return 1;
#ifdef __GNUC__
return 64 - __builtin_clzll(t);
#else
for (int i = 8 * sizeof(t) - 1; i > 0; --i) {
if (((t << i) >> i) == t) return 64 - i;
}
return 64;
#endif
};
auto round_if_needed = [&](unsigned c) {
if (round_to_bytes) {
unsigned bytes = (c + 7) / 8;
if ((bytes & (bytes - 1)) == 0) {
return bytes;
}
return 1u << num_significant_bits(bytes);
}
return c;
};
if (x.dtype().is_uint() || v > 0) { // v == 0 handled earlier
return DataType::UInt(round_if_needed(num_significant_bits(v)));
} else if (x.dtype().is_int()) {
if (v < 0) {
if (v == std::numeric_limits<decltype(v)>::min()) return x.dtype();
v = -v;
}
return DataType::Int(round_if_needed(1 + num_significant_bits(v)));
}
return x.dtype();
}
if (const auto* fpimm = x.as<FloatImmNode>()) {
if (fpimm->dtype.bits() == 64 && static_cast<double>(static_cast<float>(fpimm->value))) {
return DataType::Float(32);
}
}
return x.dtype();
}

} // namespace tir

// additional const expression overloading
Expand Down
20 changes: 15 additions & 5 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,22 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {

void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
ICHECK(range.defined());
if (tir::is_one(range->extent)) {
this->Bind(var, range->min, allow_override);

Optional<PrimExpr> maybe_min = tir::try_reset_expr_dtype(var.dtype(), range->min);
Optional<PrimExpr> maybe_extent = tir::try_reset_expr_dtype(var.dtype(), range->extent);

CHECK(maybe_min && maybe_extent)
<< "Incompatible types when binding a variable " << var << ':' << var.dtype()
<< " to a range min=" << range->min << ':' << range->min.dtype()
<< ", extent=" << range->extent << ':' << range->extent.dtype();

if (tir::is_one(maybe_extent.value())) {
this->Bind(var, maybe_min.value(), allow_override);
} else {
this->const_int_bound.Bind(var, range, allow_override);
this->int_set.Bind(var, range, allow_override);
this->transitive_comparisons.Bind(var, range, allow_override);
auto new_range = Range::FromMinExtent(maybe_min.value(), maybe_extent.value());
this->const_int_bound.Bind(var, new_range, allow_override);
this->int_set.Bind(var, new_range, allow_override);
this->transitive_comparisons.Bind(var, new_range, allow_override);
}
// skip modular_set
// skip rewrite simplify
Expand Down
23 changes: 17 additions & 6 deletions src/arith/transitive_comparison_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -573,19 +573,30 @@ void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range&
}
}

prev_bindings_.Set(var, range);
Optional<PrimExpr> maybe_min = tir::try_reset_expr_dtype(var.dtype(), range->min);
Optional<PrimExpr> maybe_extent = tir::try_reset_expr_dtype(var.dtype(), range->extent);

if (is_const_int(range->extent, 1)) {
AddKnown(var == range->min, &knowns_);
CHECK(maybe_min && maybe_extent)
<< "Incompatible types when binding a variable " << var << ':' << var.dtype()
<< " to a range min=" << range->min << ':' << range->min.dtype()
<< ", extent=" << range->extent << ':' << range->extent.dtype();

auto new_range = Range::FromMinExtent(maybe_min.value(), maybe_extent.value());
prev_bindings_.Set(var, new_range);

if (is_const_int(new_range->extent, 1)) {
AddKnown(var == new_range->min, &knowns_);
} else {
AddKnown(var >= range->min, &knowns_);
AddKnown(var < range->min + range->extent, &knowns_);
AddKnown(var >= new_range->min, &knowns_);
AddKnown(var < new_range->min + new_range->extent, &knowns_);
}
}

void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
bool allow_override) {
Bind(var, Range::FromMinExtent(expr, 1), allow_override);
if (expr.dtype().is_integer_type()) {
Bind(var, Range::FromMinExtent(expr, 1), allow_override);
}
}

std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
Expand Down
13 changes: 10 additions & 3 deletions src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,16 @@ IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span
<< "The dtype of the domain of an IterVar must be an integer type. However, the domain's "
"dtype is "
<< dom->extent.dtype();
CHECK_EQ(dom->extent.dtype(), var.dtype())
<< "The dtype of the extent of an IterVar (" << dom->extent.dtype()
<< ") must match its associated Var's dtype (" << var.dtype() << ")";

Optional<PrimExpr> maybe_min = tir::try_reset_expr_dtype(var.dtype(), dom->min);
Optional<PrimExpr> maybe_extent = tir::try_reset_expr_dtype(var.dtype(), dom->extent);

CHECK(maybe_min && maybe_extent)
<< "Incompatible types when binding a variable " << var << ':' << var.dtype()
<< " to a range min=" << dom->min << ':' << dom->min.dtype() << ", extent=" << dom->extent
<< ':' << dom->extent.dtype();

dom = Range::FromMinExtent(maybe_min.value(), maybe_extent.value());
}
n->dom = dom;
n->var = var;
Expand Down
23 changes: 8 additions & 15 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,22 +116,15 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
ICHECK(loop_var.dtype().is_scalar());
ICHECK(body.defined());

// When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them
// without raising errors.
auto try_promote_imm_dtype = [&](const PrimExpr& e) {
ICHECK(e.dtype().bits() <= loop_var.dtype().bits())
<< " Loop variable's dtype (" << loop_var.dtype()
<< ") is narrower than that of `min` or `extent` (" << e.dtype() << ")";
const IntImmNode* a = e.as<IntImmNode>();
if (a && e.dtype().bits() < loop_var.dtype().bits()) {
return make_const(loop_var.dtype(), a->value);
} else {
return e;
}
};
Optional<PrimExpr> maybe_min = tir::try_reset_expr_dtype(loop_var.dtype(), min);
Optional<PrimExpr> maybe_extent = tir::try_reset_expr_dtype(loop_var.dtype(), extent);

CHECK(maybe_min && maybe_extent) << "Incompatible types when binding a loop variable " << loop_var
<< ':' << loop_var.dtype() << " to a range min=" << min << ':'
<< min.dtype() << ", extent=" << extent << ':' << extent.dtype();

min = try_promote_imm_dtype(min);
extent = try_promote_imm_dtype(extent);
min = maybe_min.value();
extent = maybe_extent.value();

ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype();
ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype();
Expand Down