diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index ac7e879a644d..99f5adf93c49 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -119,6 +119,8 @@ class DataType { bool is_vector_bool() const { return is_vector() && bits() == 1; } /*! \return whether type is a Void type. */ bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; } + /* ! \return whether type is a signed or an unsigned int. */ + bool is_integer_type() const { return is_int() || is_uint(); } /*! * \brief Create a new data type by change lanes to a specified value. * \param lanes The target number of lanes. diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index ce4a4d6a2845..eee5e1f3ec6d 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -849,7 +849,6 @@ inline bool is_const_int(const PrimExpr& x); /*! * \brief Check whether x is an integer/float constant. - * \note This only return true for integer types. * \return whether x is constant */ inline bool is_const_number(const PrimExpr& x); @@ -882,6 +881,30 @@ inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array try_reset_expr_dtype(DataType t, const PrimExpr& x); + // Implementation details after this inline bool is_const_int(const PrimExpr& x) { return as_const_int(x); } @@ -973,6 +996,76 @@ inline PrimExpr make_zero(DataType t, Span span) { return make_const(t, 0, span); } +inline Optional try_reset_expr_dtype(DataType t, const PrimExpr& x) { + Optional none = NullOpt; + if (!is_const_number(x)) { + return x.dtype() == t ? x : none; + } + + DataType narrow_t = restricted_type(x, false); + if (narrow_t.code() == t.code()) { + return narrow_t.bits() <= t.bits() ? tvm::cast(t, x) : none; + } + if (narrow_t.is_int() && t.is_uint()) { + // Non-negative integer immediates will always have the restricted type of uint, + // so if the type is int, the immediate must be negative, hence not representable + // in an unsigned type. + return none; + } else if (narrow_t.is_uint() && t.is_int()) { + // The non-negative immediate can be switched to a signed type, if the signed type + // has at least one more bit. + return narrow_t.bits() < t.bits() ? tvm::cast(t, x) : none; + } + return none; +} + +inline DataType restricted_type(const PrimExpr& x, bool round_to_bytes) { + if (const int64_t* val = as_const_int(x)) { + int64_t v = *val; + if (x.dtype().is_integer_type() && (v == 0 || v == 1)) { + return DataType::Bool(); + } + auto num_significant_bits = [](uint64_t t) -> unsigned { + // Always return at least 1. + if (t == 0) return 1; +#ifdef __GNUC__ + return 64 - __builtin_clzll(t); +#else + for (int i = 8 * sizeof(t) - 1; i > 0; --i) { + if (((t << i) >> i) == t) return 64 - i; + } + return 64; +#endif + }; + auto round_if_needed = [&](unsigned c) { + if (round_to_bytes) { + unsigned bytes = (c + 7) / 8; + if ((bytes & (bytes - 1)) == 0) { + return bytes; + } + return 1u << num_significant_bits(bytes); + } + return c; + }; + if (x.dtype().is_uint() || v > 0) { // v == 0 handled earlier + return DataType::UInt(round_if_needed(num_significant_bits(v))); + } else if (x.dtype().is_int()) { + if (v < 0) { + if (v == std::numeric_limits::min()) return x.dtype(); + v = -v; + } + return DataType::Int(round_if_needed(1 + num_significant_bits(v))); + } + return x.dtype(); + } + if (const auto* fpimm = x.as()) { + if (fpimm->dtype.bits() == 64 && static_cast(static_cast(fpimm->value))) { + return DataType::Float(32); + } + } + return x.dtype(); +} + } // namespace tir // additional const expression overloading diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 3e5b8834ebca..8a58c61e3a52 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -53,12 +53,22 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { ICHECK(range.defined()); - if (tir::is_one(range->extent)) { - this->Bind(var, range->min, allow_override); + + Optional maybe_min = tir::try_reset_expr_dtype(var.dtype(), range->min); + Optional maybe_extent = tir::try_reset_expr_dtype(var.dtype(), range->extent); + + CHECK(maybe_min && maybe_extent) + << "Incompatible types when binding a variable " << var << ':' << var.dtype() + << " to a range min=" << range->min << ':' << range->min.dtype() + << ", extent=" << range->extent << ':' << range->extent.dtype(); + + if (tir::is_one(maybe_extent.value())) { + this->Bind(var, maybe_min.value(), allow_override); } else { - this->const_int_bound.Bind(var, range, allow_override); - this->int_set.Bind(var, range, allow_override); - this->transitive_comparisons.Bind(var, range, allow_override); + auto new_range = Range::FromMinExtent(maybe_min.value(), maybe_extent.value()); + this->const_int_bound.Bind(var, new_range, allow_override); + this->int_set.Bind(var, new_range, allow_override); + this->transitive_comparisons.Bind(var, new_range, allow_override); } // skip modular_set // skip rewrite simplify diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc index 52010ec322c8..25d2f83287a6 100644 --- a/src/arith/transitive_comparison_analyzer.cc +++ b/src/arith/transitive_comparison_analyzer.cc @@ -573,19 +573,30 @@ void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& } } - prev_bindings_.Set(var, range); + Optional maybe_min = tir::try_reset_expr_dtype(var.dtype(), range->min); + Optional maybe_extent = tir::try_reset_expr_dtype(var.dtype(), range->extent); - if (is_const_int(range->extent, 1)) { - AddKnown(var == range->min, &knowns_); + CHECK(maybe_min && maybe_extent) + << "Incompatible types when binding a variable " << var << ':' << var.dtype() + << " to a range min=" << range->min << ':' << range->min.dtype() + << ", extent=" << range->extent << ':' << range->extent.dtype(); + + auto new_range = Range::FromMinExtent(maybe_min.value(), maybe_extent.value()); + prev_bindings_.Set(var, new_range); + + if (is_const_int(new_range->extent, 1)) { + AddKnown(var == new_range->min, &knowns_); } else { - AddKnown(var >= range->min, &knowns_); - AddKnown(var < range->min + range->extent, &knowns_); + AddKnown(var >= new_range->min, &knowns_); + AddKnown(var < new_range->min + new_range->extent, &knowns_); } } void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override) { - Bind(var, Range::FromMinExtent(expr, 1), allow_override); + if (expr.dtype().is_integer_type()) { + Bind(var, Range::FromMinExtent(expr, 1), allow_override); + } } std::function TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) { diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index d590f8b2dd8b..fd3daa8e6253 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -153,9 +153,16 @@ IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span << "The dtype of the domain of an IterVar must be an integer type. However, the domain's " "dtype is " << dom->extent.dtype(); - CHECK_EQ(dom->extent.dtype(), var.dtype()) - << "The dtype of the extent of an IterVar (" << dom->extent.dtype() - << ") must match its associated Var's dtype (" << var.dtype() << ")"; + + Optional maybe_min = tir::try_reset_expr_dtype(var.dtype(), dom->min); + Optional maybe_extent = tir::try_reset_expr_dtype(var.dtype(), dom->extent); + + CHECK(maybe_min && maybe_extent) + << "Incompatible types when binding a variable " << var << ':' << var.dtype() + << " to a range min=" << dom->min << ':' << dom->min.dtype() << ", extent=" << dom->extent + << ':' << dom->extent.dtype(); + + dom = Range::FromMinExtent(maybe_min.value(), maybe_extent.value()); } n->dom = dom; n->var = var; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 1d1e674a9dd1..6c4cd449f0ed 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -116,22 +116,15 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, ICHECK(loop_var.dtype().is_scalar()); ICHECK(body.defined()); - // When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them - // without raising errors. - auto try_promote_imm_dtype = [&](const PrimExpr& e) { - ICHECK(e.dtype().bits() <= loop_var.dtype().bits()) - << " Loop variable's dtype (" << loop_var.dtype() - << ") is narrower than that of `min` or `extent` (" << e.dtype() << ")"; - const IntImmNode* a = e.as(); - if (a && e.dtype().bits() < loop_var.dtype().bits()) { - return make_const(loop_var.dtype(), a->value); - } else { - return e; - } - }; + Optional maybe_min = tir::try_reset_expr_dtype(loop_var.dtype(), min); + Optional maybe_extent = tir::try_reset_expr_dtype(loop_var.dtype(), extent); + + CHECK(maybe_min && maybe_extent) << "Incompatible types when binding a loop variable " << loop_var + << ':' << loop_var.dtype() << " to a range min=" << min << ':' + << min.dtype() << ", extent=" << extent << ':' << extent.dtype(); - min = try_promote_imm_dtype(min); - extent = try_promote_imm_dtype(extent); + min = maybe_min.value(); + extent = maybe_extent.value(); ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype(); ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype();